提交 5ed8d923 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't convert stuff to TensorVariable in safe_new()

上级 be12a5cf
......@@ -594,7 +594,9 @@ def scan(fn,
if init_out.get('taps', None) == [-1]:
actual_arg = init_out['initial']
arg = safe_new(init_out['initial'])
if not isinstance(actual_arg, tensor.Variable):
actual_arg = tensor.as_tensor_variable(actual_arg)
arg = safe_new(actual_arg)
if isinstance(arg, tensor.Constant):
# safe new returns a clone of the constants, but that is not
# what we need for initial states
......
......@@ -64,14 +64,6 @@ def safe_new(x, tag='', dtype=None):
nw_x = x.type()
nw_x.name = nw_name
return nw_x
else:
try:
x = tensor.as_tensor_variable(x)
except TypeError:
# This could happen for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda
# ndarrays
pass
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论