提交 1a9d77ab authored 作者: Frederic's avatar Frederic

Fix a bug Yao had. Is this the right fix?

上级 f09723ba
...@@ -127,11 +127,21 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -127,11 +127,21 @@ def remove_constants_and_unused_inputs_scan(node):
if scan_utils.equal_computations( if scan_utils.equal_computations(
[x], [nw_out])] [x], [nw_out])]
if identical_non_seqs: if identical_non_seqs:
identical_idx = outer_non_seqs.index(identical_non_seqs[0])
# If we have identical non sequences, the previous one
# must be in nw_inner or be a constant.
assert (non_seqs[identical_idx] in nw_inner or
isinstance(identical_non_seqs[0], tensor.Constant))
index = outer_non_seqs.index(identical_non_seqs[0]) index = outer_non_seqs.index(identical_non_seqs[0])
givens[nw_in] = non_seqs[index] givens[nw_in] = non_seqs[index]
else: else:
nw_inner += [nw_in] nw_inner += [nw_in]
nw_outer += [nw_out] nw_outer += [nw_out]
else:
# How this can happen? This case happened and if we remove
# this else, the assert in the elif will fail.
nw_inner += [nw_in]
nw_outer += [nw_out]
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace=givens) op_outs = scan_utils.clone(op_outs, replace=givens)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论