提交 f1f123b8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #510 from delallea/test_value_scan_fix

Test value scan fix
......@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
......@@ -58,12 +58,22 @@ def safe_new(x, tag=''):
try:
x = tensor.as_tensor_variable(x)
except TypeError:
# This could happend for example for random states, and I really
# This could happen for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda
# ndarrays
pass
nw_x = x.type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though.
if theano.config.compute_test_value != 'off':
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
pass
return nw_x
......@@ -212,9 +222,9 @@ def get_updates_and_outputs(ls):
'The return value of your scan lambda expression may only be '
'made of lists, tuples, or dictionaries containing Theano '
'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant values, '
'you can use `tensor.constant` to turn them into Theano '
'variables.')
'conditions). In particular if you need to use constant '
'values, you can use `tensor.constant` to turn them into '
'Theano variables.')
if is_outputs(ls):
return None, _list(ls), {}
......
......@@ -3310,3 +3310,26 @@ if __name__ == '__main__':
print 37
scan_tst.test_save_mem_store_steps()
#'''
def test_compute_test_value():
"""
Verify that test values can be used with scan.
"""
backup = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
try:
x = tensor.vector()
xv = numpy.ones(3, dtype=theano.config.floatX)
x.tag.test_value = xv
y = theano.shared(numpy.arange(3, dtype=theano.config.floatX))
z, _ = theano.scan(
fn=lambda u, v: u + v,
sequences=[x, y])
assert not _
# The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x)
#f = theano.function([x], g)
#print f(xv)
finally:
theano.config.compute_test_value = backup
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论