提交 e645134c authored 作者: Frederic's avatar Frederic

Do equivalent fix in Scan.grad() on other new input sequence.

Also add assert when test_value are available.
上级 31b07c03
...@@ -1581,7 +1581,15 @@ class Scan(PureOp): ...@@ -1581,7 +1581,15 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)] if hasattr(inputs[0].tag, 'test_value'):
for x in self.outer_mitsot_outs(outs):
if hasattr(x[::-1][:inputs[0]].tag, 'test_value'):
assert x[::-1][:inputs[0]].tag.test_value.shape[0] == inputs[0].tag.test_value
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert x[::-1][:-1].tag.test_value.shape[0] == inputs[0].tag.test_value
outer_inp_seqs += [x[::-1][:inputs[0]]
for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论