提交 e8db4482 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix bug reported in ticket 766 by Fred ( passing numpy values or ints to

scan)
上级 f9813f0f
...@@ -340,8 +340,15 @@ def scan( fn ...@@ -340,8 +340,15 @@ def scan( fn
seqs = wrap_into_list(sequences) seqs = wrap_into_list(sequences)
outs_info = wrap_into_list(outputs_info) outs_info = wrap_into_list(outputs_info)
non_seqs = wrap_into_list(non_sequences)
# Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan
non_seqs = []
for elem in wrap_into_list(non_sequences):
if not isinstance(elem, gof.Variable):
non_seqs.append(tensor.as_tensor_variable(elem))
else:
non_seqs.append(elem)
# If we provided a known number of steps ( before compilation) # If we provided a known number of steps ( before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op, # and if that number is 1 or -1, then we can skip the Scan Op,
...@@ -380,7 +387,7 @@ def scan( fn ...@@ -380,7 +387,7 @@ def scan( fn
# wrap outputs info in a dictionary if they are not already in one # wrap outputs info in a dictionary if they are not already in one
for i in xrange(n_outs): for i in xrange(n_outs):
if outs_info[i]: if outs_info[i] is not None:
if isinstance(outs_info[i], dict): if isinstance(outs_info[i], dict):
# DEPRICATED : # DEPRICATED :
if outs_info[i].get('return_steps', None): if outs_info[i].get('return_steps', None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论