提交 629017bb，作者：Pascal Lamblin

First fix in scan_merge: clone each inner graph independently

上级 1d639d66
...@@ -1100,9 +1100,13 @@ class ScanMerge(gof.Optimizer): ...@@ -1100,9 +1100,13 @@ class ScanMerge(gof.Optimizer):
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = nodes[0].op.profile info['profile'] = nodes[0].op.profile
inner_ins = [] # We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone,
# we also need to split them by types (seq, mitmot, ...).
# On the other hand, outer_ins, outer_outs and info are held together.
inner_ins = [[] for nd in nodes]
outer_ins = [] outer_ins = []
inner_outs = [] inner_outs = [[] for nd in nodes]
outer_outs = [] outer_outs = []
def rename(ls, suffix): def rename(ls, suffix):
...@@ -1113,13 +1117,14 @@ class ScanMerge(gof.Optimizer): ...@@ -1113,13 +1117,14 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Seq # Seq
inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx) inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx) inner_ins[idx].append(
inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs) rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
info['tap_array'] += nd.op.mitmot_taps() info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps() info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
...@@ -1127,51 +1132,108 @@ class ScanMerge(gof.Optimizer): ...@@ -1127,51 +1132,108 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot # MitSot
inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx) inner_ins[idx].append(
inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs) rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
info['tap_array'] += nd.op.mitsot_taps() info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx) inner_ins[idx].append(
rename(nd.op.inner_sitsot(nd.op.inputs), idx))
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)] info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs) inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx) inner_ins[idx].append(
rename(nd.op.inner_shared(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx) outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs) inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs) outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs) outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs(nd.op.outputs) inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx) inner_ins[idx].append(
rename(nd.op.inner_non_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx) outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps # Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins outer_ins = [nodes[0].inputs[0]] + outer_ins
if as_while: if as_while:
# add the condition # add the condition, which was the one of nodes[0]
inner_outs.append(condition) inner_outs[0].append([condition])
inner_ins, inner_outs = scan_utils.reconstruct_graph(inner_ins,
inner_outs)
new_op = scan_op.Scan(inner_ins, inner_outs, info) # Clone the inner graph of each node independently
for idx, nd in enumerate(nodes):
# concatenate all inner_ins and inner_outs of nd
flat_inner_ins = sum(inner_ins[idx], [])
flat_inner_outs = sum(inner_outs[idx], [])
# clone
flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
flat_inner_ins, flat_inner_outs)
# split the new inner variables again in seq, mitmot, etc.
new_inner_ins = []
count = 0
for nl in inner_ins[idx]:
seq_len = len(nl)
new_inner_ins.append(flat_inner_ins[count:(count + seq_len)])
count += seq_len
new_inner_outs = []
count = 0
for nl in inner_outs[idx]:
seq_len = len(nl)
new_inner_outs.append(flat_inner_outs[count:(count + seq_len)])
count += seq_len
inner_ins[idx] = new_inner_ins
inner_outs[idx] = new_inner_outs
# Flatten inner_ins and inner_outs so that all seqs are first,
# then mitmot, etc.
new_inner_ins = []
new_inner_outs = []
nb_ins_groups = len(inner_ins[0])
nb_outs_groups = len(inner_outs[0])
for idx, nd in enumerate(nodes):
# All inner_ins should have the same length
assert len(inner_ins[idx]) == nb_ins_groups
# All inner_outs should have the same length, except if as_while,
# in which case the first one should have one more element
if as_while and idx > 0:
assert len(inner_outs[idx]) == nb_outs_groups - 1
else:
assert len(inner_outs[idx]) == nb_outs_groups
for gr_idx in range(nb_ins_groups):
for idx, nd in enumerate(nodes):
new_inner_ins += inner_ins[idx][gr_idx]
for gr_idx in range(nb_outs_groups):
for idx, nd in enumerate(nodes):
if as_while and idx > 0 and gr_idx == (nb_outs_groups - 1):
# There is no condition on that node, skip it
pass
else:
new_inner_outs += inner_outs[idx][gr_idx]
new_op = scan_op.Scan(new_inner_ins, new_inner_outs, info)
new_outs = new_op(*outer_ins) new_outs = new_op(*outer_ins)
if not isinstance(new_outs, (list, tuple)): if not isinstance(new_outs, (list, tuple)):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论