提交 67f3ff83 authored 作者: lamblin's avatar lamblin

Merge pull request #1314 from delallea/test_value_fix

Fixed test values with ``ifelse``
......@@ -3,6 +3,7 @@
The `Op` class is the base interface for all operations
compatible with `gof`'s :doc:`graph` routines.
"""
__authors__ = "theano-dev"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
......@@ -373,7 +374,16 @@ class PureOp(object):
`default_output`, but subclasses are free to override this function and ignore
`default_output`.
:param inputs: The Op's inputs, forwarded to the call to `make_node()`.
:param kwargs: Additional keyword arguments to be forwarded to
`make_node()` *except* for optional argument `return_list` (which
defaults to False). If `return_list` is True, then the returned
value is always a list. Otherwise it is either a single Variable
when the output of `make_node()` contains a single element, or this
output (unchanged) when it contains multiple elements.
"""
return_list = kwargs.pop('return_list', False)
node = self.make_node(*inputs, **kwargs)
if self.add_stack_trace_on_call:
self.add_tag_trace(node)
......@@ -434,9 +444,14 @@ class PureOp(object):
output.tag.test_value = storage_map[output][0]
if self.default_output is not None:
return node.outputs[self.default_output]
rval = node.outputs[self.default_output]
if return_list:
rval = [rval]
return rval
else:
if len(node.outputs) == 1:
if return_list:
return list(node.outputs)
elif len(node.outputs) == 1:
return node.outputs[0]
else:
return node.outputs
......@@ -449,7 +464,6 @@ class PureOp(object):
# Python implementation #
#########################
def R_op(self, inputs, eval_points):
"""
......@@ -480,7 +494,6 @@ class PureOp(object):
"own op, implement the R_op method." %
(self, self.__class__.__name__))
def perform(self, node, inputs, output_storage):
"""
Required: Calculate the function on the inputs and put the variables in the
......@@ -610,6 +623,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False
return rval
def get_test_value(v):
"""
Extract test value from `v`. Raises AttributeError if there is none.
......
......@@ -141,8 +141,9 @@ class IfElse(PureOp):
as_view=False,
gpu=False,
name='_'.join(name_tokens))
new_outs = new_ifelse.make_node(node.inputs[0],
*(new_ts_inputs + new_fs_inputs)).outputs
new_outs = new_ifelse(node.inputs[0],
*(new_ts_inputs + new_fs_inputs),
**dict(return_list=True))
else:
new_outs = []
......@@ -187,7 +188,7 @@ class IfElse(PureOp):
return Apply(self, [c] + list(args), [t.type() for t in ts])
def R_op(self, inputs, eval_points):
return self.make_node(inputs[0], *eval_points[1:]).outputs
return self(inputs[0], *eval_points[1:], **dict(return_list=True))
def grad(self, ins, grads):
ts = ins[1:][:self.n_outs]
......@@ -220,8 +221,8 @@ class IfElse(PureOp):
condition_grad = condition.zeros_like().astype(theano.config.floatX)
return ([condition_grad] +
if_true_op.make_node(*if_true).outputs +
if_false_op.make_node(*if_false).outputs)
if_true_op(*if_true, **dict(return_list=True)) +
if_false_op(*if_false, **dict(return_list=True)))
def make_thunk(self, node, storage_map, compute_map, no_recycling):
outtypes = [out.type for out in node.outputs]
......@@ -326,9 +327,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
new_else_branch = []
for then_branch_elem, else_branch_elem in izip(then_branch, else_branch):
if not isinstance(then_branch_elem, theano.Variable):
then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem)
then_branch_elem = theano.tensor.as_tensor_variable(
then_branch_elem)
if not isinstance(else_branch_elem, theano.Variable):
else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem)
else_branch_elem = theano.tensor.as_tensor_variable(
else_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
# If one of them is a TensorType, and the other one can be
......@@ -371,7 +374,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse.make_node(*ins).outputs
rval = new_ifelse(*ins, **dict(return_list=True))
if rval_type is None:
return rval[0]
......@@ -388,7 +391,7 @@ def cond_make_inplace(node):
return IfElse(n_outs=op.n_outs,
as_view=True,
gpu=op.gpu,
name=op.name).make_node(*node.inputs).outputs
name=op.name)(*node.inputs, **dict(return_list=True))
return False
......@@ -481,14 +484,12 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
true_eval = mop(*true_ins, **dict(return_list=True))
false_eval = mop(*false_ins, **dict(return_list=True))
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
return nw_outs
......@@ -513,7 +514,7 @@ def cond_merge_ifs_true(node):
old_ins = list(node.inputs)
for pos, var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
return op(*old_ins, **dict(return_list=True))
@gof.local_optimizer([None])
......@@ -538,7 +539,7 @@ def cond_merge_ifs_false(node):
old_ins = list(node.inputs)
for pos, var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
return op(*old_ins, **dict(return_list=True))
class CondMerge(gof.Optimizer):
......@@ -576,7 +577,7 @@ class CondMerge(gof.Optimizer):
gpu=False,
name=mn_name + '&' + pl_name)
print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
new_outs = [clone(x) for x in new_outs]
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
......@@ -630,7 +631,7 @@ def cond_remove_identical(node):
name=op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
rval = []
for idx in xrange(len(node.outputs)):
......@@ -681,7 +682,7 @@ def cond_merge_random_op(main_node):
as_view=False,
gpu=False,
name=mn_name + '&' + pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = new_ifelse(*new_ins, **dict(return_list=True))
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs]
......
......@@ -367,6 +367,20 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
assert numpy.allclose(f(vx1, vx2, vy1, vy2, vw1, vw2, 0),
vx2 + vy2 + vw2)
def test_grad_test_values(self):
    """
    Regression test for the test values of the `ifelse` gradient.

    Builds the gradient of an `ifelse` expression while
    `theano.config.compute_test_value` is set to 'raise', then puts the
    previous setting back.
    """
    previous_setting = theano.config.compute_test_value
    theano.config.compute_test_value = 'raise'
    try:
        scalar_in = tensor.scalar('x')
        scalar_in.tag.test_value = 1
        # Both branches are the same variable. Taking this gradient used
        # to crash due to an undefined test value.
        branch_out = ifelse(0, scalar_in, scalar_in)
        tensor.grad(branch_out, scalar_in)
    finally:
        # Restore the global flag whether or not the gradient call raised.
        theano.config.compute_test_value = previous_setting
if __name__ == '__main__':
    # Running this module directly only prints a hint to use nosetests.
    # The parenthesized single-argument form emits the same text under
    # both the Python 2 print statement and the Python 3 print function,
    # so the module stays importable on either interpreter.
    print(' Use nosetests to run these tests ')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论