提交 6af465ee authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix when using compute_test_value option with scan

This avoids warnings / crashes (depending on the value of the compute_test_value config option) caused by scan's cloned variables. Also fixed a couple small typos.
上级 2dff4ecf
......@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
......@@ -58,12 +58,22 @@ def safe_new(x, tag=''):
try:
x = tensor.as_tensor_variable(x)
except TypeError:
# This could happend for example for random states, and I really
# This could happen for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda
# ndarrays
pass
nw_x = x.type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though.
if theano.config.compute_test_value != 'off':
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
pass
return nw_x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论