提交 f7e4787b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the method of pretty printing in module

上级 b2f8af83
......@@ -33,7 +33,7 @@ class LogisticRegressionN(module.FancyModule):
self.params = [self.w, self.b]
xent, y = nnet_ops.crossentropy_softmax_1hot(
xent, y = nnet.crossentropy_softmax_1hot(
T.dot(self.x, self.w) + self.b, self.targ)
xent = T.sum(xent)
......@@ -69,7 +69,7 @@ class LogisticRegression2(module.FancyModule):
self.params = [self.w, self.b]
y = nnet_ops.sigmoid(T.dot(self.x, self.w))
y = nnet.sigmoid(T.dot(self.x, self.w))
xent_elem = -self.targ * T.log(y) - (1.0 - self.targ) * T.log(1.0 - y)
xent = T.sum(xent_elem)
......@@ -86,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
if __name__ == '__main__':
pprint.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx'))
pprint.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax'))
pprint.assign(nnet.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx'))
pprint.assign(nnet.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax'))
if 1:
lrc = LogisticRegressionN()
......
......@@ -43,8 +43,9 @@ class Component(object):
try:
return self.dup().bind(parent, name, False)
except BindError, e:
#TODO: Add a hint that this could be caused by a buggy dup() that doesn't
#follow it's contract
e.args = (e.args[0] +
' ; This seems to have been caused by an implementation of dup'
' that keeps the previous binding (%s)' % self.dup,) + e.args[1:]
raise
else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
......@@ -196,7 +197,7 @@ class Method(Component):
def allocate(self, memo):
return None
def build(self, mode, memo):
def build(self, mode, memo, allocate_all = False):
self.resolve_all()
def get_storage(r, require = False):
try:
......@@ -215,7 +216,7 @@ class Method(Component):
for input in inputs]
inputs += [io.In(result = k,
update = v,
value = get_storage(k, True),
value = get_storage(k, not allocate_all),
mutable = True,
strict = True)
for k, v in self.updates.iteritems()]
......@@ -226,17 +227,13 @@ class Method(Component):
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value):
inputs += [io.In(result = input,
value = get_storage(input, True),
value = get_storage(input, not allocate_all),
mutable = False)]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
inputs += [(kit, get_storage(kit, not allocate_all)) for kit in self.kits]
return F.function(inputs, outputs, mode)
def pretty(self, **kwargs):
self.resolve_all()
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += "Method(%s):" % ", ".join(map(str, self.inputs))
if self.inputs:
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else:
......@@ -244,19 +241,23 @@ class Method(Component):
mode = kwargs.pop('mode', None)
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
if mode:
nin = len(inputs)
nout = len(outputs)
k, v = zip(*updates.items()) if updates else ((), ())
nup = len(k)
eff_in = tuple(inputs) + tuple(k)
eff_out = tuple(outputs) + tuple(v)
supp_in = tuple(gof.graph.inputs(eff_out))
env = gof.Env(*gof.graph.clone(eff_in + supp_in,
eff_out))
sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)]))
env.extend(sup)
mode.optimizer.optimize(env)
inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
f = self.build(mode, {}, True)
einputs, eoutputs = f.maker.env.inputs, f.maker.env.outputs
updates = dict(((k, v) for k, v in zip(einputs[len(inputs):], eoutputs[len(outputs):])))
inputs, outputs = einputs[:len(inputs)], eoutputs[:len(outputs)]
# nin = len(inputs)
# nout = len(outputs)
# k, v = zip(*updates.items()) if updates else ((), ())
# nup = len(k)
# eff_in = tuple(inputs) + tuple(k)
# eff_out = tuple(outputs) + tuple(v)
# supp_in = tuple(gof.graph.inputs(eff_out))
# env = gof.Env(*gof.graph.clone(eff_in + supp_in,
# eff_out))
# sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)]))
# env.extend(sup)
# mode.optimizer.optimize(env)
# inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
rval += pprint(inputs, outputs, updates, False)
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论