提交 a2ea4492 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix to optimization to deal with identical nit sot of different lengths

上级 c36bdaae
...@@ -1385,10 +1385,46 @@ def scan_merge_inouts(node): ...@@ -1385,10 +1385,46 @@ def scan_merge_inouts(node):
seen.append((i, o)) seen.append((i, o))
return o return o
def map_nitsot_out(i, o, sh, seen):
for p,(si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right):
if equal_computations([sh], [ssh]):
return so
try:
vsh = int(opt.get_constant_value(sh))
vssh = int(opt.get_constant_value(ssh))
except:
return o
if vsh == vssh:
return so
elif vsh > vssh:
seen[p] = (i,o,sh)
return o
else:
return so[:vsh]
seen.append((i, o, sh))
return o
seen = [] seen = []
na.outer_out_nit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_nit_sot, shapes = []
na.outer_out_nit_sot)] for x in na.outer_in_nit_sot:
if x.ndim > 0:
shapes.append(
node.fgraph.shape_feature.shape_of[x][0])
else:
shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) na.outer_out_sit_sot = [map_out(i, o, seen)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论