提交 21f87538 authored 作者: --global's avatar --global

scan_pushout_seq : avoid removing computation unrelated to sequences

上级 5d1915c3
...@@ -426,15 +426,19 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -426,15 +426,19 @@ class PushOutSeqScan(gof.Optimizer):
not nd in to_remove): not nd in to_remove):
to_remove.append(nd) to_remove.append(nd)
outside_ins = [] outside_ins = []
depends_on_seqs = False
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs: if x in inner_non_seqs:
_idx = inner_non_seqs.index(x) _idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]] outside_ins += [outer_non_seqs[_idx]]
elif x in inner_seqs: elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]] outside_ins += [outer_seqs[inner_seqs.index(x)]]
depends_on_seqs = True
elif x in to_replace: elif x in to_replace:
outside_ins += [replace_with_out[ outside_ins += [replace_with_out[
to_replace.index(x)]] to_replace.index(x)]]
depends_on_seqs = True
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins += [x.clone()] outside_ins += [x.clone()]
else: else:
...@@ -444,6 +448,15 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -444,6 +448,15 @@ class PushOutSeqScan(gof.Optimizer):
'to move some computation fron scan ' 'to move some computation fron scan '
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
if not depends_on_seqs:
# Removing this node from the inner graph of scan
# should be handled by the PushOutNonSeqScan
# optimization. The current optimization only tries
# to pull sequence-dependant computation out of
# scan.
continue
# Do not call make_node for test_value # Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins, nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner **dict(return_list=True))[0].owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论