提交 3caaba8c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #1052 from lamblin/fix_test_value

In safe_new, assign test value before casting
......@@ -55,7 +55,7 @@ def safe_new(x, tag='', dtype=None):
# making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable):
if dtype:
new_x = scalar.Scalar(dtype=dtype)()
nw_x = scalar.Scalar(dtype=dtype)()
else:
nw_x = x.type()
nw_x.name = nw_name
......@@ -69,8 +69,6 @@ def safe_new(x, tag='', dtype=None):
# ndarrays
pass
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
......@@ -82,6 +80,10 @@ def safe_new(x, tag='', dtype=None):
except AttributeError:
# This means `x` has no test value.
pass
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
return nw_x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论