提交 63e73e69 authored 作者: Frederic's avatar Frederic

small code refactoring

上级 2b80093d
...@@ -96,25 +96,26 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -96,25 +96,26 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins = gof.graph.inputs(op_outs) all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx + 1], tensor.TensorConstant) and node_inp = node.inputs[idx + 1]
node.inputs[idx + 1].tag.unique_value is not None): if (isinstance(node_inp, tensor.TensorConstant) and
node_inp.tag.unique_value is not None):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0] givens[op_ins[idx]] = node_inp.clone()[0]
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
# Check for identical other sequence # Check for identical other sequence
identical_seqs = [x for x in nw_outer identical_seqs = [x for x in nw_outer
if scan_utils.equal_computations( if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])] [x], [node_inp])]
if identical_seqs: if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1 index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index] givens[op_ins[idx]] = op_ins[index]
else: else:
nw_inner += [op_ins[idx]] nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx + 1]] nw_outer += [node_inp]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论