提交 f1913900 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix bug reported by Nicolas

上级 60ec239c
...@@ -1268,6 +1268,8 @@ class Scan(PureOp):
# It means the gradient is undefined (which implies
# is connected)
gmp[x] = x
except gradient.DisconnectedInputError:
gmp[x] = None
return [gmp.get(p, None) for p in diff_inputs]
def _get_inner_outs(oidx):
......
...@@ -3427,6 +3427,15 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4)
def test_grad_bug_disconnected_input(self):
    """Regression test for the bug reported by Nicolas: taking the
    gradient of a scan output w.r.t. a non-sequence whose inner input
    is disconnected used to raise instead of yielding the gradient."""
    weights = theano.shared(numpy.zeros((3, 3)), name='W')
    indices = theano.tensor.ivector(name='v')
    rows, _ = theano.scan(lambda i, w: w[i],
                          sequences=indices,
                          outputs_info=None,
                          non_sequences=weights)
    # Building this gradient function used to raise an exception.
    grad_fn = theano.function([indices],
                              theano.tensor.grad(rows.sum(), weights))
    expected = [[0, 0, 0], [1, 1, 1], [1, 1, 1]]
    assert numpy.allclose(grad_fn([1, 2]), expected)
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论