提交 062ead13 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed test values with ``ifelse``

By using ``__call__`` instead of ``make_node``, test values are properly set.
上级 009177d7
...@@ -141,8 +141,9 @@ class IfElse(PureOp): ...@@ -141,8 +141,9 @@ class IfElse(PureOp):
as_view=False, as_view=False,
gpu=False, gpu=False,
name='_'.join(name_tokens)) name='_'.join(name_tokens))
new_outs = new_ifelse.make_node(node.inputs[0], new_outs = new_ifelse(node.inputs[0],
*(new_ts_inputs + new_fs_inputs)).outputs *(new_ts_inputs + new_fs_inputs),
**dict(return_list=True))
else: else:
new_outs = [] new_outs = []
...@@ -187,7 +188,7 @@ class IfElse(PureOp): ...@@ -187,7 +188,7 @@ class IfElse(PureOp):
return Apply(self, [c] + list(args), [t.type() for t in ts]) return Apply(self, [c] + list(args), [t.type() for t in ts])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return self.make_node(inputs[0], *eval_points[1:]).outputs return self(inputs[0], *eval_points[1:], **dict(return_list=True))
def grad(self, ins, grads): def grad(self, ins, grads):
ts = ins[1:][:self.n_outs] ts = ins[1:][:self.n_outs]
...@@ -220,8 +221,8 @@ class IfElse(PureOp): ...@@ -220,8 +221,8 @@ class IfElse(PureOp):
condition_grad = condition.zeros_like().astype(theano.config.floatX) condition_grad = condition.zeros_like().astype(theano.config.floatX)
return ([condition_grad] + return ([condition_grad] +
if_true_op.make_node(*if_true).outputs + if_true_op(*if_true, **dict(return_list=True)) +
if_false_op.make_node(*if_false).outputs) if_false_op(*if_false, **dict(return_list=True)))
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
outtypes = [out.type for out in node.outputs] outtypes = [out.type for out in node.outputs]
...@@ -371,7 +372,7 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -371,7 +372,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
name=name) name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch) ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse.make_node(*ins).outputs rval = new_ifelse(*ins, **dict(return_list=True))
if rval_type is None: if rval_type is None:
return rval[0] return rval[0]
...@@ -388,7 +389,7 @@ def cond_make_inplace(node): ...@@ -388,7 +389,7 @@ def cond_make_inplace(node):
return IfElse(n_outs=op.n_outs, return IfElse(n_outs=op.n_outs,
as_view=True, as_view=True,
gpu=op.gpu, gpu=op.gpu,
name=op.name).make_node(*node.inputs).outputs name=op.name)(*node.inputs, **dict(return_list=True))
return False return False
...@@ -481,14 +482,12 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node): ...@@ -481,14 +482,12 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
else: else:
true_ins.append(x) true_ins.append(x)
false_ins.append(x) false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop.make_node(*false_ins).outputs false_eval = mop(*false_ins, **dict(return_list=True))
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts))) #true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs))) #false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval) nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
return nw_outs return nw_outs
...@@ -513,7 +512,7 @@ def cond_merge_ifs_true(node): ...@@ -513,7 +512,7 @@ def cond_merge_ifs_true(node):
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos, var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op.make_node(*old_ins).outputs return op(*old_ins, **dict(return_list=True))
@gof.local_optimizer([None]) @gof.local_optimizer([None])
...@@ -538,7 +537,7 @@ def cond_merge_ifs_false(node): ...@@ -538,7 +537,7 @@ def cond_merge_ifs_false(node):
old_ins = list(node.inputs) old_ins = list(node.inputs)
for pos, var in replace.items(): for pos, var in replace.items():
old_ins[pos] = var old_ins[pos] = var
return op.make_node(*old_ins).outputs return op(*old_ins, **dict(return_list=True))
class CondMerge(gof.Optimizer): class CondMerge(gof.Optimizer):
...@@ -576,7 +575,7 @@ class CondMerge(gof.Optimizer): ...@@ -576,7 +575,7 @@ class CondMerge(gof.Optimizer):
gpu=False, gpu=False,
name=mn_name + '&' + pl_name) name=mn_name + '&' + pl_name)
print 'here' print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse(*new_ins, **dict(return_list=True))
new_outs = [clone(x) for x in new_outs] new_outs = [clone(x) for x in new_outs]
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
...@@ -630,7 +629,7 @@ def cond_remove_identical(node): ...@@ -630,7 +629,7 @@ def cond_remove_identical(node):
name=op.name) name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse(*new_ins, **dict(return_list=True))
rval = [] rval = []
for idx in xrange(len(node.outputs)): for idx in xrange(len(node.outputs)):
...@@ -681,7 +680,7 @@ def cond_merge_random_op(main_node): ...@@ -681,7 +680,7 @@ def cond_merge_random_op(main_node):
as_view=False, as_view=False,
gpu=False, gpu=False,
name=mn_name + '&' + pl_name) name=mn_name + '&' + pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs new_outs = new_ifelse(*new_ins, **dict(return_list=True))
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs] old_outs += [merging_node.outputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论