提交 f10a85bd authored 作者: Razvan Pascanu's avatar Razvan Pascanu

If somehow inputs of scan end up being numpy ndarrays cast them to constants

上级 6abe55be
...@@ -411,9 +411,10 @@ def scan( fn ...@@ -411,9 +411,10 @@ def scan( fn
# the output of the lambda expression directly to replace # the output of the lambda expression directly to replace
# the output of scan. # the output of scan.
# If not we need to use copies, that will be replaced at # If not we need to use copies, that will be replaced at
# each frame by the corresponding slice # each frame by the corresponding slice
nw_slice = seq['input'][0].type() _seq_val = tensor.as_tensor_variable(seq['input'])
nw_slice = _seq_val[0].type()
actual_slice = seq['input'][k-mintap] actual_slice = seq['input'][k-mintap]
...@@ -574,7 +575,8 @@ def scan( fn ...@@ -574,7 +575,8 @@ def scan( fn
for k in init_out['taps']: for k in init_out['taps']:
# create a new slice # create a new slice
actual_nw_slice = init_out['initial'][k+mintap] actual_nw_slice = init_out['initial'][k+mintap]
nw_slice = init_out['initial'][0].type() _init_out_var = tensor.as_tensor_variable(init_out['initial'])
nw_slice = _init_out_var[0].type()
# give it a name or debugging and pretty printing # give it a name or debugging and pretty printing
if getattr(init_out['initial'],'name', None) is not None: if getattr(init_out['initial'],'name', None) is not None:
......
...@@ -49,6 +49,7 @@ def info(*msg): ...@@ -49,6 +49,7 @@ def info(*msg):
def safe_new(x): def safe_new(x):
x = tensor.as_tensor_variable(x)
if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType): if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType):
return tensor.TensorType( return tensor.TensorType(
broadcastable = x.type.broadcastable broadcastable = x.type.broadcastable
...@@ -57,6 +58,7 @@ def safe_new(x): ...@@ -57,6 +58,7 @@ def safe_new(x):
return x.type() return x.type()
def safe_to_cpu(x): def safe_to_cpu(x):
x = tensor.as_tensor_variable(x)
if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType): if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType):
return cuda.basic_ops.host_from_gpu(x) return cuda.basic_ops.host_from_gpu(x)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论