提交 062ead13 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed test values with ``ifelse``

By using ``__call__`` instead of ``make_node``, test values are properly set.
上级 009177d7
......@@ -141,8 +141,9 @@ class IfElse(PureOp):
as_view=False,
gpu=False,
name='_'.join(name_tokens))
new_outs = new_ifelse.make_node(node.inputs[0],
*(new_ts_inputs + new_fs_inputs)).outputs
new_outs = new_ifelse(node.inputs[0],
*(new_ts_inputs + new_fs_inputs),
**dict(return_list=True))
else:
new_outs = []
......@@ -187,7 +188,7 @@ class IfElse(PureOp):
return Apply(self, [c] + list(args), [t.type() for t in ts])
def R_op(self, inputs, eval_points):
return self.make_node(inputs[0], *eval_points[1:]).outputs
return self(inputs[0], *eval_points[1:], **dict(return_list=True))
def grad(self, ins, grads):
ts = ins[1:][:self.n_outs]
......@@ -220,8 +221,8 @@ class IfElse(PureOp):
condition_grad = condition.zeros_like().astype(theano.config.floatX)
return ([condition_grad] +
if_true_op.make_node(*if_true).outputs +
if_false_op.make_node(*if_false).outputs)
if_true_op(*if_true, **dict(return_list=True)) +
if_false_op(*if_false, **dict(return_list=True)))
def make_thunk(self, node, storage_map, compute_map, no_recycling):
outtypes = [out.type for out in node.outputs]
......@@ -371,7 +372,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse.make_node(*ins).outputs
rval = new_ifelse(*ins, **dict(return_list=True))
if rval_type is None:
return rval[0]
......@@ -388,7 +389,7 @@ def cond_make_inplace(node):
return IfElse(n_outs=op.n_outs,
as_view=True,
gpu=op.gpu,
name=op.name).make_node(*node.inputs).outputs
name=op.name)(*node.inputs, **dict(return_list=True))
return False
......@@ -481,14 +482,12 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop(*false_ins, **dict(return_list=True))
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
return nw_outs
......@@ -513,7 +512,7 @@ def cond_merge_ifs_true(node):
old_ins = list(node.inputs)
for pos, var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
return op(*old_ins, **dict(return_list=True))
@gof.local_optimizer([None])
......@@ -538,7 +537,7 @@ def cond_merge_ifs_false(node):
old_ins = list(node.inputs)
for pos, var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
return op(*old_ins, **dict(return_list=True))
class CondMerge(gof.Optimizer):
......@@ -576,7 +575,7 @@ class CondMerge(gof.Optimizer):
gpu=False,
name=mn_name + '&' + pl_name)
print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
new_outs = [clone(x) for x in new_outs]
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
......@@ -630,7 +629,7 @@ def cond_remove_identical(node):
name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
rval = []
for idx in xrange(len(node.outputs)):
......@@ -681,7 +680,7 @@ def cond_merge_random_op(main_node):
as_view=False,
gpu=False,
name=mn_name + '&' + pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论