提交 5b53939c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

removed wip code

上级 08d73bae
...@@ -1512,26 +1512,6 @@ def scan_pushout_dot1(node): ...@@ -1512,26 +1512,6 @@ def scan_pushout_dot1(node):
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(op.inputs) assert len(inner_ins_shapes) == len(op.inputs)
grab_shape = scan_utils.infer_shape(
[new_scan_out], op.inputs, inner_ins_shapes)[0]
out_equivalent = {}
for in_ns, out_ns in zip(op.inner_non_seqs(op.inputs),
op.outer_non_seqs(node.inputs)):
out_equivalent[in_ns] = out_ns
validator = scan_utils.Validator(
valid=input_shapes,
invalid=op.inputs,
valid_equivalent=out_equivalent)
grab_shape = [validator.check(x) for x in grab_shape]
if numpy.all([x is not None for x in grab_shape]) and \
op.nit_sot_buffers:
new_mem_buffer = tensor.zeros(
[node.inputs[0]] +
[x[0] for x in grab_shape])
new_info['nit_sot_buffers'] = True
else:
new_info['nit_sot_buffers'] = False
_new_inner_inps = (inner_seqs + _new_inner_inps = (inner_seqs +
inner_mitmot + inner_mitmot +
inner_mitsot + inner_mitsot +
...@@ -1548,26 +1528,15 @@ def scan_pushout_dot1(node): ...@@ -1548,26 +1528,15 @@ def scan_pushout_dot1(node):
_new_inner_inps, _new_inner_outs) _new_inner_inps, _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info) new_info)
if new_info['nit_sot_buffers']: _scan_inputs = ([node.inputs[0]] +
_scan_inputs = ([node.inputs[0]] + outer_seqs +
outer_seqs + outer_mitmot +
outer_mitmot + outer_mitsot +
outer_mitsot + outer_sitsot +
outer_sitsot + outer_shared +
outer_shared + outer_nitsot +
outer_nitsot + [node.inputs[0]] +
[new_mem_buffer] + outer_non_seqs)
outer_non_seqs)
else:
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs) new_outs = new_op(*_scan_inputs)
# We need now to pair correctly the new outputs with the # We need now to pair correctly the new outputs with the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论