提交 af8f19b6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add new checks in optimization, to prevent incorrect replacement.

We want to avoid merging different outputs of a scan node that compute the same thing but keep a different number of steps.
上级 d011f68e
...@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node): ...@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
return False return False
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(node.inputs, node.outputs, a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info) node.op.inputs, node.op.outputs, node.op.info)
...@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node): ...@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node):
else: else:
na = a na = a
# start again # Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things
# from the same inputs.
left = [] left = []
right = [] right = []
...@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node): ...@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node):
else: else:
seen[(oms, sl)] = ims seen[(oms, sl)] = ims
def map_out(i, o, seen): def map_out(outer_i, inner_o, outer_o, seen):
for si, so in seen: # Return the outer output corresponding to an
if equal_computations([i], [si], left, right): # (outer input, inner output) pair. If we see that pair for the first
return so # time, return the provided outer output. If an equivalent pair had
seen.append((i, o)) # already been seen, return that one instead.
return o # Note that we need to check that the outer inputs match as well,
# because they could have different sizes, and the corresponding
def map_nitsot_out(i, o, sh, seen): # outer outputs cannot be merged in that case.
for p, (si, so, ssh) in enumerate(seen): for s_outer_i, s_inner_o, s_outer_o in seen:
if equal_computations([i], [si], left, right): if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
return s_outer_o
seen.append((outer_i, inner_o, outer_o))
return outer_o
def map_nitsot_out(outer_i, inner_o, outer_o, sh, seen):
# Like map_out, but also checks the needed shape.
for p, (s_outer_i, s_inner_o, s_outer_o, ssh) in enumerate(seen):
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
if equal_computations([sh], [ssh]): if equal_computations([sh], [ssh]):
return so return s_outer_o
try: try:
vsh = int(opt.get_scalar_constant_value(sh)) vsh = int(opt.get_scalar_constant_value(sh))
vssh = int(opt.get_scalar_constant_value(ssh)) vssh = int(opt.get_scalar_constant_value(ssh))
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
return o return outer_o
if vsh == vssh: if vsh == vssh:
return so return s_outer_o
elif vsh > vssh: elif vsh > vssh:
seen[p] = (i, o, sh) seen[p] = (outer_i, inner_o, outer_o, sh)
return o return outer_o
else: else:
return so[:vsh] return s_outer_o[:vsh]
seen.append((i, o, sh)) seen.append((outer_i, inner_o, outer_o, sh))
return o return outer_o
seen = [] seen = []
...@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node): ...@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node):
# If x is a scalar, then it means its value is the number of # If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence # items scan is supposed to store for this nit_sot sequence
shapes.append(x) shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen) assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
for i, o, sh in zip(na.inner_out_nit_sot, assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
na.outer_out_nit_sot, assert len(na.outer_out_nit_sot) == len(shapes)
shapes)] na.outer_out_nit_sot = [
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen) map_nitsot_out(outer_i, inner_o, outer_o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot, for outer_i, inner_o, outer_o, sh in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot, na.outer_out_nit_sot,
shapes)] shapes)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
for i, o in zip(na.inner_out_sit_sot, assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot)
na.outer_out_sit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot,
na.inner_out_sit_sot,
na.outer_out_sit_sot)] na.outer_out_sit_sot)]
seen = [] seen = []
na.outer_out_mit_sot = [map_out(i, o, seen) assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot)
for i, o in zip(na.inner_out_mit_sot, assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot)
na.outer_out_mit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot,
na.inner_out_mit_sot,
na.outer_out_mit_sot)] na.outer_out_mit_sot)]
seen = [] seen = []
new_outer_out_mit_mot = [] new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot, assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot)
na.outer_out_mit_mot, na.mit_mot_out_slices): assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot)
for simm, somm, sosl in seen: assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices)
if osl == sosl and equal_computations(imm, simm, left, right): for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot,
new_outer_out_mit_mot.append(somm) na.inner_out_mit_mot,
na.outer_out_mit_mot,
na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl
and equal_computations(inner_omm, s_inner_omm, left, right)
and outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm)
break break
else: else:
seen.append((imm, omm, osl)) seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
return na.outer_outputs return na.outer_outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论