提交 ae576886 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

safe_new should return a variable not a graph

上级 087493ec
......@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
nw_name = None
if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype:
return x.clone().astype(dtype)
casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name)
nwx.tag = copy(x.tag)
return nwx
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a
......@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
# ndarrays
pass
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
......@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
# This means `x` has no test value.
pass
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
return nw_x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论