提交 a8be1ef7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Cast variable before cloning, so test value gets casted

上级 2846a3aa
...@@ -73,11 +73,13 @@ def safe_new(x, tag='', dtype=None): ...@@ -73,11 +73,13 @@ def safe_new(x, tag='', dtype=None):
# want to avoid the convoluted logic that checks for cuda # want to avoid the convoluted logic that checks for cuda
# ndarrays # ndarrays
pass pass
# Cast x if needed. If x has a test value, this will also cast it.
if dtype and x.dtype != dtype:
x = x.astype(dtype)
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may # between test values, due to inplace operations for instance. This may
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论