提交 44768424 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix gradient of split when there are disconnected outputs

This should fix gh-1589
上级 e5d769dc
......@@ -3225,8 +3225,23 @@ class Split(Op):
def grad(self, inputs, g_outputs):
    """Join the gradients along the axis that was used to split x.

    Parameters
    ----------
    inputs : (x, axis, n)
        The original inputs of the Split op: the tensor that was split,
        the axis along which it was split, and the number of splits.
    g_outputs : list
        One gradient term per output of the split; entries may be
        instances of DisconnectedType when that output does not
        influence the cost (see gh-1589).

    Returns
    -------
    list
        Gradient w.r.t. each input. The gradients w.r.t. ``axis`` and
        ``n`` are undefined (they are integer/structural inputs).
    """
    x, axis, n = inputs
    # Re-apply the op to get symbolic handles on the outputs, so that
    # disconnected gradient slots can be replaced by zeros of the
    # matching shape below.
    outputs = self(*inputs, **dict(return_list=True))
    # If every output gradient is disconnected, the inputs are
    # disconnected too: propagate DisconnectedType instead of zeros.
    if python_all([isinstance(g.type, DisconnectedType)
                   for g in g_outputs]):
        return [DisconnectedType()(),
                grad_undefined(self, 1, axis),
                grad_undefined(self, 2, n)]
    # Otherwise, replace each disconnected gradient by zeros shaped
    # like the corresponding output so that join() gets consistent
    # operands.
    new_g_outputs = [o.zeros_like() if isinstance(g.type, DisconnectedType)
                     else g
                     for o, g in zip(outputs, g_outputs)]
    return [join(axis, *new_g_outputs),
            grad_undefined(self, 1, axis),
            grad_undefined(self, 2, n)]
......
......@@ -44,7 +44,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
dtensor3, SpecifyShape, Mean,
itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle)
stacklists, DimShuffle, hessian)
from theano.tests import unittest_tools as utt
......@@ -3067,6 +3067,14 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in topo if isinstance(n, self.join_op)]) == 0
assert f.maker.fgraph.outputs[0].dtype == 'int64'
def test_stack_hessian(self):
    # Regression test for gh-1589: taking the gradient of stack a
    # second time (via hessian) must not fail on disconnected outputs.
    vec_a = tensor.dvector('a')
    vec_b = tensor.dvector('b')
    stacked = stack([vec_a, vec_b])
    gram = stacked.T.dot(stacked)
    # Only checks that this compiles/differentiates without error.
    hessian(gram.sum(), [vec_a, vec_b])
def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input'''
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论