提交 3b4ccf8b，作者: abergeron

Merge pull request #3481 from carriepl/scan_grad_speedup_bug

Scan grad speedup bug
...@@ -2052,7 +2052,6 @@ class Scan(PureOp): ...@@ -2052,7 +2052,6 @@ class Scan(PureOp):
dc_dxts_idx += 1 dc_dxts_idx += 1
else: else:
if isinstance(dC_douts[i].type, DisconnectedType): if isinstance(dC_douts[i].type, DisconnectedType):
dc_dxts_idx += 1
continue continue
else: else:
if diff_outputs[i] in known_grads: if diff_outputs[i] in known_grads:
......
...@@ -4057,6 +4057,26 @@ class T_Scan(unittest.TestCase): ...@@ -4057,6 +4057,26 @@ class T_Scan(unittest.TestCase):
# scan could not detect the connection between `m2` and `x` # scan could not detect the connection between `m2` and `x`
tensor.grad(m2.sum(), m) tensor.grad(m2.sum(), m)
def test_disconnected_gradient3(self):
    """Check the gradient through a scan whose first output is unused.

    Regression test for a crash that would occur sometimes when taking
    the gradient through a scan with a non-recurrent output which would
    receive a disconnected gradient.
    """
    v = tensor.dvector('v')

    def step(seq):
        # Two chained, non-recurrent outputs; only the second one feeds
        # the cost below, so the first receives a disconnected gradient.
        first = seq + 1
        second = first + 1
        return first, second

    scan_outputs, _ = theano.scan(step, sequences=v)
    _, out2 = scan_outputs

    cost = out2.sum()
    grad_wrt_v = tensor.grad(cost, [v])
    f = theano.function([v], grad_wrt_v)

    # Ensure the output of the function is valid: out2 == v + 2, so the
    # gradient of its sum with respect to v is all ones.
    sample = numpy.random.random(5)
    output = f(sample)
    utt.assert_allclose(output, numpy.ones(5))
def test_dot_optimization(self): def test_dot_optimization(self):
A = tensor.matrix('A') A = tensor.matrix('A')
B = tensor.matrix('B') B = tensor.matrix('B')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论