提交 958f0cab authored 作者: Razvan Pascanu's avatar Razvan Pascanu

reuse functionality for outputs

上级 9184e0f2
......@@ -190,26 +190,19 @@ class PushOutNonSeqScan(gof.Optimizer):
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot
st += op.n_shared_outs
non_seqs = clean_inputs[st:]
st = (op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs)
inner_non_seqs = op.inner_non_seqs(clean_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)
while changed and counts < max_iterations:
counts += 1
changed = False
for nd in local_fgraph.toposort():
if (numpy.all([(x in non_seqs) or
if (numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
......@@ -226,8 +219,9 @@ class PushOutNonSeqScan(gof.Optimizer):
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
if x in non_seqs:
outside_ins += [outer_non_seqs[non_seqs.index(x)]]
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in to_replace:
outside_ins += [
replace_with_out[to_replace.index(x)]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论