提交 9184e0f2 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge identical inputs as well (seqs and non-seqs)

上级 6fc6d024
...@@ -103,20 +103,38 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -103,20 +103,38 @@ def remove_constants_and_unused_inputs_scan(node):
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]] # Check for identical other sequence
nw_outer += [node.inputs[idx + 1]] identical_seqs = [x for x in nw_outer
if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])]
if identical_seqs and False:
index = node.inputs.index(identical_seqs[0]) - 1
if op_ins[index] not in givens.keys():
givens[op_ins[idx]] = op_ins[index]
else:
givens[op_ins[idx]] = givens[op_ins[index]]
else:
nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx + 1]]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
nw_inner += out_stuff_inner nw_inner += out_stuff_inner
nw_outer += out_stuff_outer nw_outer += out_stuff_outer
# Look through non sequences # Look through non sequences
for nw_in, nw_out in zip(non_seqs, outer_non_seqs): for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)):
if isinstance(nw_out, tensor.Constant): if isinstance(nw_out, tensor.Constant):
givens[nw_in] = nw_out.clone() givens[nw_in] = nw_out.clone()
elif nw_in in all_ins: elif nw_in in all_ins:
nw_inner += [nw_in] identical_non_seqs = [x for x in outer_non_seqs[:idx]
nw_outer += [nw_out] if scan_utils.equal_computations(
[x], [nw_out])]
if identical_non_seqs:
index = outer_non_seqs.index(identical_non_seqs[0])
givens[nw_in] = non_seqs[index]
else:
nw_inner += [nw_in]
nw_outer += [nw_out]
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace=givens) op_outs = scan_utils.clone(op_outs, replace=givens)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论