提交 1b4988c6 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix bug reported by Ilya of scan adding extra inputs when it needs not to.

上级 e8db4482
...@@ -773,13 +773,16 @@ def scan( fn ...@@ -773,13 +773,16 @@ def scan( fn
# extract still missing inputs (there still might be so) and add them # extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args # as non sequences at the end of our args
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = scan_utils.clone(outputs,
replace=dict(zip(non_seqs,
fake_nonseqs)))
all_inputs = itertools.ifilter( all_inputs = itertools.ifilter(
lambda x: ( isinstance(x, gof.Variable) and lambda x: ( isinstance(x, gof.Variable) and
not isinstance(x, SharedVariable) and not isinstance(x, SharedVariable) and
not isinstance(x, gof.Constant) ), not isinstance(x, gof.Constant) ),
gof.graph.inputs( outputs) ) gof.graph.inputs( fake_outputs) )
extra_inputs = filter( lambda x: x not in args, extra_inputs = filter( lambda x: x not in args + fake_nonseqs,
all_inputs) all_inputs)
non_seqs += extra_inputs non_seqs += extra_inputs
## Note we do not use all_inputs directly since the order of variables ## Note we do not use all_inputs directly since the order of variables
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论