提交 f7e4787b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the method of pretty printing in module

上级 b2f8af83
...@@ -33,7 +33,7 @@ class LogisticRegressionN(module.FancyModule): ...@@ -33,7 +33,7 @@ class LogisticRegressionN(module.FancyModule):
self.params = [self.w, self.b] self.params = [self.w, self.b]
xent, y = nnet_ops.crossentropy_softmax_1hot( xent, y = nnet.crossentropy_softmax_1hot(
T.dot(self.x, self.w) + self.b, self.targ) T.dot(self.x, self.w) + self.b, self.targ)
xent = T.sum(xent) xent = T.sum(xent)
...@@ -69,7 +69,7 @@ class LogisticRegression2(module.FancyModule): ...@@ -69,7 +69,7 @@ class LogisticRegression2(module.FancyModule):
self.params = [self.w, self.b] self.params = [self.w, self.b]
y = nnet_ops.sigmoid(T.dot(self.x, self.w)) y = nnet.sigmoid(T.dot(self.x, self.w))
xent_elem = -self.targ * T.log(y) - (1.0 - self.targ) * T.log(1.0 - y) xent_elem = -self.targ * T.log(y) - (1.0 - self.targ) * T.log(1.0 - y)
xent = T.sum(xent_elem) xent = T.sum(xent_elem)
...@@ -86,8 +86,8 @@ class LogisticRegression2(module.FancyModule): ...@@ -86,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
if __name__ == '__main__': if __name__ == '__main__':
pprint.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx')) pprint.assign(nnet.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx'))
pprint.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax')) pprint.assign(nnet.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax'))
if 1: if 1:
lrc = LogisticRegressionN() lrc = LogisticRegressionN()
......
...@@ -43,8 +43,9 @@ class Component(object): ...@@ -43,8 +43,9 @@ class Component(object):
try: try:
return self.dup().bind(parent, name, False) return self.dup().bind(parent, name, False)
except BindError, e: except BindError, e:
#TODO: Add a hint that this could be caused by a buggy dup() that doesn't e.args = (e.args[0] +
#follow it's contract ' ; This seems to have been caused by an implementation of dup'
' that keeps the previous binding (%s)' % self.dup,) + e.args[1:]
raise raise
else: else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name)) raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
...@@ -196,7 +197,7 @@ class Method(Component): ...@@ -196,7 +197,7 @@ class Method(Component):
def allocate(self, memo): def allocate(self, memo):
return None return None
def build(self, mode, memo): def build(self, mode, memo, allocate_all = False):
self.resolve_all() self.resolve_all()
def get_storage(r, require = False): def get_storage(r, require = False):
try: try:
...@@ -215,7 +216,7 @@ class Method(Component): ...@@ -215,7 +216,7 @@ class Method(Component):
for input in inputs] for input in inputs]
inputs += [io.In(result = k, inputs += [io.In(result = k,
update = v, update = v,
value = get_storage(k, True), value = get_storage(k, not allocate_all),
mutable = True, mutable = True,
strict = True) strict = True)
for k, v in self.updates.iteritems()] for k, v in self.updates.iteritems()]
...@@ -226,17 +227,13 @@ class Method(Component): ...@@ -226,17 +227,13 @@ class Method(Component):
blockers = _inputs): blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value): if input not in _inputs and not isinstance(input, gof.Value):
inputs += [io.In(result = input, inputs += [io.In(result = input,
value = get_storage(input, True), value = get_storage(input, not allocate_all),
mutable = False)] mutable = False)]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits] inputs += [(kit, get_storage(kit, not allocate_all)) for kit in self.kits]
return F.function(inputs, outputs, mode) return F.function(inputs, outputs, mode)
def pretty(self, **kwargs): def pretty(self, **kwargs):
self.resolve_all() self.resolve_all()
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += "Method(%s):" % ", ".join(map(str, self.inputs))
if self.inputs: if self.inputs:
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs)) rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else: else:
...@@ -244,19 +241,23 @@ class Method(Component): ...@@ -244,19 +241,23 @@ class Method(Component):
mode = kwargs.pop('mode', None) mode = kwargs.pop('mode', None)
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
if mode: if mode:
nin = len(inputs) f = self.build(mode, {}, True)
nout = len(outputs) einputs, eoutputs = f.maker.env.inputs, f.maker.env.outputs
k, v = zip(*updates.items()) if updates else ((), ()) updates = dict(((k, v) for k, v in zip(einputs[len(inputs):], eoutputs[len(outputs):])))
nup = len(k) inputs, outputs = einputs[:len(inputs)], eoutputs[:len(outputs)]
eff_in = tuple(inputs) + tuple(k) # nin = len(inputs)
eff_out = tuple(outputs) + tuple(v) # nout = len(outputs)
supp_in = tuple(gof.graph.inputs(eff_out)) # k, v = zip(*updates.items()) if updates else ((), ())
env = gof.Env(*gof.graph.clone(eff_in + supp_in, # nup = len(k)
eff_out)) # eff_in = tuple(inputs) + tuple(k)
sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)])) # eff_out = tuple(outputs) + tuple(v)
env.extend(sup) # supp_in = tuple(gof.graph.inputs(eff_out))
mode.optimizer.optimize(env) # env = gof.Env(*gof.graph.clone(eff_in + supp_in,
inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:])) # eff_out))
# sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)]))
# env.extend(sup)
# mode.optimizer.optimize(env)
# inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
rval += pprint(inputs, outputs, updates, False) rval += pprint(inputs, outputs, updates, False)
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论