Commit 44768424, authored by Pascal Lamblin

Fix gradient of split when there are disconnected outputs

This should fix gh-1589
Parent commit: e5d769dc
...@@ -3225,8 +3225,23 @@ class Split(Op): ...@@ -3225,8 +3225,23 @@ class Split(Op):
def grad(self, inputs, g_outputs):
    """Join the gradients along the axis that was used to split x.

    If every output gradient is disconnected, the gradient wrt the
    split input is disconnected as well.  Otherwise, each disconnected
    output gradient is replaced by zeros shaped like the corresponding
    output so that all pieces can be joined back together.  The
    gradients wrt `axis` and `n` are undefined, as both are integer
    (non-differentiable) inputs.
    """
    # Only axis and n are needed here; the split input itself is not.
    _, axis, n = inputs
    # Recompute the outputs of this node so we know the shape of each
    # piece (needed to build zeros for disconnected gradients).
    # NOTE: **dict(...) is used because `f(*args, kw=1)` is a
    # SyntaxError before Python 2.7.
    outputs = self(*inputs, **dict(return_list=True))
    # If all the output gradients are disconnected, then so are the inputs.
    if python_all(isinstance(g.type, DisconnectedType)
                  for g in g_outputs):
        return [DisconnectedType()(),
                grad_undefined(self, 1, axis),
                grad_undefined(self, 2, n)]
    # Else, replace disconnected gradients by zeros before joining them.
    new_g_outputs = [o.zeros_like()
                     if isinstance(g.type, DisconnectedType) else g
                     for o, g in zip(outputs, g_outputs)]
    return [join(axis, *new_g_outputs),
            grad_undefined(self, 1, axis),
            grad_undefined(self, 2, n)]
......
...@@ -44,7 +44,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -44,7 +44,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
dtensor3, SpecifyShape, Mean, dtensor3, SpecifyShape, Mean,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle) stacklists, DimShuffle, hessian)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -3067,6 +3067,14 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3067,6 +3067,14 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in topo if isinstance(n, self.join_op)]) == 0 assert len([n for n in topo if isinstance(n, self.join_op)]) == 0
assert f.maker.fgraph.outputs[0].dtype == 'int64' assert f.maker.fgraph.outputs[0].dtype == 'int64'
def test_stack_hessian(self):
    # Regression test for gh-1589: taking a hessian through a stack
    # (i.e. through Split.grad with disconnected outputs) used to fail.
    vec_a = tensor.dvector('a')
    vec_b = tensor.dvector('b')
    stacked = stack([vec_a, vec_b])
    gram = stacked.T.dot(stacked)
    # Building the hessian must not raise.
    hessian(gram.sum(), [vec_a, vec_b])
def test_join_concatenate_one_element(self): def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join. ''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input''' also test that we remove the Join op if there is only 1 input'''
......
Markdown format supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment