提交 5b53939c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

removed wip code

上级 08d73bae
......@@ -1512,26 +1512,6 @@ def scan_pushout_dot1(node):
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(op.inputs)
grab_shape = scan_utils.infer_shape(
[new_scan_out], op.inputs, inner_ins_shapes)[0]
out_equivalent = {}
for in_ns, out_ns in zip(op.inner_non_seqs(op.inputs),
op.outer_non_seqs(node.inputs)):
out_equivalent[in_ns] = out_ns
validator = scan_utils.Validator(
valid=input_shapes,
invalid=op.inputs,
valid_equivalent=out_equivalent)
grab_shape = [validator.check(x) for x in grab_shape]
if numpy.all([x is not None for x in grab_shape]) and \
op.nit_sot_buffers:
new_mem_buffer = tensor.zeros(
[node.inputs[0]] +
[x[0] for x in grab_shape])
new_info['nit_sot_buffers'] = True
else:
new_info['nit_sot_buffers'] = False
_new_inner_inps = (inner_seqs +
inner_mitmot +
inner_mitsot +
......@@ -1548,26 +1528,15 @@ def scan_pushout_dot1(node):
_new_inner_inps, _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info)
if new_info['nit_sot_buffers']:
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[new_mem_buffer] +
outer_non_seqs)
else:
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs)
# We need now to pair correctly the new outputs with the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论