提交 6af465ee authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix when using compute_test_value option with scan

This avoids warnings / crashes (depending on the value of the compute_test_value config option) caused by scan's cloned variables. Also fixed a couple small typos.
上级 2dff4ecf
...@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils') ...@@ -36,7 +36,7 @@ _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''): def safe_new(x, tag=''):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original the inner graph such that there is no interference between the original
graph and the newly constructed graph. graph and the newly constructed graph.
...@@ -58,12 +58,22 @@ def safe_new(x, tag=''): ...@@ -58,12 +58,22 @@ def safe_new(x, tag=''):
try: try:
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
except TypeError: except TypeError:
# This could happend for example for random states, and I really # This could happen for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda # want to avoid the convoluted logic that checks for cuda
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though.
if theano.config.compute_test_value != 'off':
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
pass
return nw_x return nw_x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论