Commit d6d2a4f4 authored by Razvan Pascanu

Merge pull request #1637 from lamblin/fix_stack_grad

Fix gradient of split when there are disconnected outputs
...@@ -3225,8 +3225,23 @@ class Split(Op): ...@@ -3225,8 +3225,23 @@ class Split(Op):
def grad(self, inputs, g_outputs):
    """Join the gradients along the axis that was used to split x.

    Parameters
    ----------
    inputs : (x, axis, n)
        The inputs of the Split node: the tensor that was split, the
        split axis, and the number of splits.
    g_outputs : list of Variable
        One gradient per output of the split; any of them may be an
        instance of DisconnectedType.

    Returns
    -------
    list
        [gradient wrt x, grad_undefined for axis, grad_undefined for n].
        If every output gradient is disconnected, the gradient wrt x is
        DisconnectedType as well; otherwise disconnected output
        gradients are replaced by zeros of the matching output shape
        before being joined back along ``axis``.
    """
    _, axis, n = inputs
    # Re-apply the op to get symbolic handles on the outputs, so that
    # zeros_like() below produces zeros with the right shape/dtype.
    outputs = self(*inputs, return_list=True)
    # If all the output gradients are disconnected, then so are the inputs.
    if python_all(isinstance(g.type, DisconnectedType)
                  for g in g_outputs):
        return [DisconnectedType()(),
                grad_undefined(self, 1, axis),
                grad_undefined(self, 2, n)]
    # Else, replace each disconnected gradient by zeros before joining.
    new_g_outputs = []
    for o, g in zip(outputs, g_outputs):
        if isinstance(g.type, DisconnectedType):
            new_g_outputs.append(o.zeros_like())
        else:
            new_g_outputs.append(g)
    return [join(axis, *new_g_outputs),
            grad_undefined(self, 1, axis),
            grad_undefined(self, 2, n)]
......
...@@ -44,7 +44,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -44,7 +44,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
dtensor3, SpecifyShape, Mean, dtensor3, SpecifyShape, Mean,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle) stacklists, DimShuffle, hessian)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -3067,6 +3067,14 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3067,6 +3067,14 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in topo if isinstance(n, self.join_op)]) == 0 assert len([n for n in topo if isinstance(n, self.join_op)]) == 0
assert f.maker.fgraph.outputs[0].dtype == 'int64' assert f.maker.fgraph.outputs[0].dtype == 'int64'
def test_stack_hessian(self):
    # Regression test for gh-1589: Split.grad must cope with
    # disconnected output gradients, otherwise hessian() crashes
    # on a graph containing stack().
    vec_a = tensor.dvector('a')
    vec_b = tensor.dvector('b')
    stacked = stack([vec_a, vec_b])
    quad_form = stacked.T.dot(stacked)
    # Building the hessian graph must not raise.
    hessian(quad_form.sum(), [vec_a, vec_b])
def test_join_concatenate_one_element(self): def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join. ''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input''' also test that we remove the Join op if there is only 1 input'''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论