提交 5a0fb0e7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Properly disable unused code in Scan.infer_shape

上级 ccbf2e98
......@@ -2097,29 +2097,29 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# sequences
seqs_shape = [x[1:] for x in input_shapes[1 : 1 + info.n_seqs]]
# We disable extra infer_shape for now. See gh-3765.
extra_infer_shape = False
# if extra_infer_shape:
# inner_seqs = self.inputs[: info.n_seqs]
# outer_seqs = node.inputs[1 : 1 + info.n_seqs]
# for in_s, out_s in zip(inner_seqs, outer_seqs):
# out_equivalent[in_s] = out_s[0]
#
# # mit_mot, mit_sot, sit_sot
# outer_inp_idx = 1 + info.n_seqs
# inner_inp_idx = info.n_seqs
# else:
# outer_inp_idx = 0
outer_inp_idx = 0
if extra_infer_shape:
inner_seqs = self.inputs[: info.n_seqs]
outer_seqs = node.inputs[1 : 1 + info.n_seqs]
for in_s, out_s in zip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + info.n_seqs
inner_inp_idx = info.n_seqs
else:
outer_inp_idx = 0
n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
outs_shape = []
for idx in range(n_outs):
mintap = abs(min(info.tap_array[idx]))
abs(min(info.tap_array[idx]))
for k in info.tap_array[idx]:
outs_shape += [input_shapes[idx + info.n_seqs + 1][1:]]
if extra_infer_shape:
corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
inner_inp_idx += 1
# if extra_infer_shape:
# corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
# out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
# inner_inp_idx += 1
outer_inp_idx += 1
# shared_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论