提交 af8f19b6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add new checks in optimization, to prevent incorrect replacement.

We want to avoid merging different outputs of a scan node, that compute the same thing, but keep a different number of steps.
上级 d011f68e
......@@ -1280,6 +1280,9 @@ def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan):
return False
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
......@@ -1332,7 +1335,9 @@ def scan_merge_inouts(node):
else:
na = a
# start again
# Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things
# from the same inputs.
left = []
right = []
......@@ -1369,32 +1374,42 @@ def scan_merge_inouts(node):
else:
seen[(oms, sl)] = ims
def map_out(i, o, seen):
for si, so in seen:
if equal_computations([i], [si], left, right):
return so
seen.append((i, o))
return o
def map_nitsot_out(i, o, sh, seen):
for p, (si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right):
def map_out(outer_i, inner_o, outer_o, seen):
# Return the outer input corresponding to an
# (outer input, inner output) pair. If we see that pair for the first
# time, return the provided outer output. If an equivalent pair had
# already been seen, return that one instead.
# Note that we need to check that the outer input match as well,
# because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case.
for s_outer_i, s_inner_o, s_outer_o in seen:
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
return s_outer_o
seen.append((outer_i, inner_o, outer_o))
return outer_o
def map_nitsot_out(outer_i, inner_o, outer_o, sh, seen):
# Like map_out, but also checks the needed shape.
for p, (s_outer_i, s_inner_o, s_outer_o, ssh) in enumerate(seen):
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
if equal_computations([sh], [ssh]):
return so
return s_outer_o
try:
vsh = int(opt.get_scalar_constant_value(sh))
vssh = int(opt.get_scalar_constant_value(ssh))
except tensor.NotScalarConstantError:
return o
return outer_o
if vsh == vssh:
return so
return s_outer_o
elif vsh > vssh:
seen[p] = (i, o, sh)
return o
seen[p] = (outer_i, inner_o, outer_o, sh)
return outer_o
else:
return so[:vsh]
seen.append((i, o, sh))
return o
return s_outer_o[:vsh]
seen.append((outer_i, inner_o, outer_o, sh))
return outer_o
seen = []
......@@ -1410,36 +1425,52 @@ def scan_merge_inouts(node):
# If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence
shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
assert len(na.outer_out_nit_sot) == len(shapes)
na.outer_out_nit_sot = [
map_nitsot_out(outer_i, inner_o, outer_o, sh, seen)
for outer_i, inner_o, outer_o, sh in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot,
na.outer_out_sit_sot)]
assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot)
na.outer_out_sit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot,
na.inner_out_sit_sot,
na.outer_out_sit_sot)]
seen = []
na.outer_out_mit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_mit_sot,
na.outer_out_mit_sot)]
assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot)
assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot)
na.outer_out_mit_sot = [
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot,
na.inner_out_mit_sot,
na.outer_out_mit_sot)]
seen = []
new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot,
na.outer_out_mit_mot, na.mit_mot_out_slices):
for simm, somm, sosl in seen:
if osl == sosl and equal_computations(imm, simm, left, right):
new_outer_out_mit_mot.append(somm)
assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot)
assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot)
assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices)
for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot,
na.inner_out_mit_mot,
na.outer_out_mit_mot,
na.mit_mot_out_slices):
for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
if (osl == sosl
and equal_computations(inner_omm, s_inner_omm, left, right)
and outer_imm == s_outer_imm):
new_outer_out_mit_mot.append(s_outer_omm)
break
else:
seen.append((imm, omm, osl))
new_outer_out_mit_mot.append(omm)
seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot
return na.outer_outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论