Commit f1f123b8 authored by Razvan Pascanu

Merge pull request #510 from delallea/test_value_scan_fix

Test value scan fix
...@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils') ...@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''): def safe_new(x, tag=''):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original the inner graph such that there is no interference between the original
graph and the newly constructed graph. graph and the newly constructed graph.
...@@ -58,12 +58,22 @@ def safe_new(x, tag=''): ...@@ -58,12 +58,22 @@ def safe_new(x, tag=''):
try: try:
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
except TypeError: except TypeError:
# This could happend for example for random states, and I really # This could happen for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda # want to avoid the convoluted logic that checks for cuda
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though.
if theano.config.compute_test_value != 'off':
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
pass
return nw_x return nw_x
...@@ -212,9 +222,9 @@ def get_updates_and_outputs(ls): ...@@ -212,9 +222,9 @@ def get_updates_and_outputs(ls):
'The return value of your scan lambda expression may only be ' 'The return value of your scan lambda expression may only be '
'made of lists, tuples, or dictionaries containing Theano ' 'made of lists, tuples, or dictionaries containing Theano '
'variables (or `theano.scan_module.until` objects for ' 'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant values, ' 'conditions). In particular if you need to use constant '
'you can use `tensor.constant` to turn them into Theano ' 'values, you can use `tensor.constant` to turn them into '
'variables.') 'Theano variables.')
if is_outputs(ls): if is_outputs(ls):
return None, _list(ls), {} return None, _list(ls), {}
......
...@@ -3310,3 +3310,26 @@ if __name__ == '__main__': ...@@ -3310,3 +3310,26 @@ if __name__ == '__main__':
print 37 print 37
scan_tst.test_save_mem_store_steps() scan_tst.test_save_mem_store_steps()
#''' #'''
def test_compute_test_value():
    """
    Verify that test values can be used with scan.

    Runs a simple elementwise scan with `compute_test_value` set to
    'raise', so any operation missing a test value aborts the test, then
    takes the gradient of the result (which used to crash).
    """
    backup = theano.config.compute_test_value
    theano.config.compute_test_value = 'raise'
    try:
        x = tensor.vector()
        xv = numpy.ones(3, dtype=theano.config.floatX)
        x.tag.test_value = xv
        y = theano.shared(numpy.arange(3, dtype=theano.config.floatX))
        z, updates = theano.scan(
            fn=lambda u, v: u + v,
            sequences=[x, y])
        # A pure elementwise scan should produce no shared-variable updates.
        assert not updates
        # The gradient computation used to crash before 6af465e; computing
        # it is the point of the test, the value itself is not checked.
        tensor.grad(z.sum(), x)
    finally:
        # Always restore the original flag, even if the test fails.
        theano.config.compute_test_value = backup
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment