提交 21f87538 authored 作者: --global's avatar --global

scan_pushout_seq : avoid removing computation unrelated to sequences

上级 5d1915c3
......@@ -426,15 +426,19 @@ class PushOutSeqScan(gof.Optimizer):
not nd in to_remove):
to_remove.append(nd)
outside_ins = []
depends_on_seqs = False
for x in nd.inputs:
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]]
depends_on_seqs = True
elif x in to_replace:
outside_ins += [replace_with_out[
to_replace.index(x)]]
depends_on_seqs = True
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
else:
......@@ -444,6 +448,15 @@ class PushOutSeqScan(gof.Optimizer):
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'), x)
if not depends_on_seqs:
# Removing this node from the inner graph of scan
# should be handled by the PushOutNonSeqScan
# optimization. The current optimization only tries
# to pull sequence-dependant computation out of
# scan.
continue
# Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论