提交 e645134c authored 作者: Frederic's avatar Frederic

Do equivalent fix in Scan.grad() on other new input sequence.

Also add assert when test_value are available.
上级 31b07c03
...@@ -1581,7 +1581,15 @@ class Scan(PureOp): ...@@ -1581,7 +1581,15 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)] if hasattr(inputs[0].tag, 'test_value'):
for x in self.outer_mitsot_outs(outs):
if hasattr(x[::-1][:inputs[0]].tag, 'test_value'):
assert x[::-1][:inputs[0]].tag.test_value.shape[0] == inputs[0].tag.test_value
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert x[::-1][:-1].tag.test_value.shape[0] == inputs[0].tag.test_value
outer_inp_seqs += [x[::-1][:inputs[0]]
for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
...@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase): ...@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards=False) go_backwards=False)
gX, gY = tensor.grad(values[1].sum(), [x, y]) gX, gY = tensor.grad(values[1].sum(), [x, y])
f = theano.function([c, x, y], [gX, gY], f = theano.function([c, x, y], [gX, gY],
allow_input_downcast=True) allow_input_downcast=True)
# Check for runtime errors # Check for runtime errors
f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5)) f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论