提交 8d5c376f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove error in hessian if gradient is constant

上级 2c710d23
...@@ -1636,12 +1636,17 @@ def hessian(cost, wrt, consider_constant=None, ...@@ -1636,12 +1636,17 @@ def hessian(cost, wrt, consider_constant=None,
assert input.ndim == 1, \ assert input.ndim == 1, \
"tensor.hessian expects a (list of) 1 dimensional variable "\ "tensor.hessian expects a (list of) 1 dimensional variable "\
"as `wrt`" "as `wrt`"
expr = grad(cost, input) expr = grad(cost, input, consider_constant=consider_constant,
disconnected_inputs=disconnected_inputs)
# It is possible that the inputs are disconnected from expr,
# even if they are connected to cost.
# This should not be an error.
hess, updates = theano.scan(lambda i, y, x: grad( hess, updates = theano.scan(lambda i, y, x: grad(
y[i], y[i],
x, x,
consider_constant=consider_constant, consider_constant=consider_constant,
disconnected_inputs=disconnected_inputs), disconnected_inputs='ignore'),
sequences=arange(expr.shape[0]), sequences=arange(expr.shape[0]),
non_sequences=[expr, input]) non_sequences=[expr, input])
assert not updates, \ assert not updates, \
......
...@@ -3086,6 +3086,28 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3086,6 +3086,28 @@ class T_Join_and_Split(unittest.TestCase):
assert numpy.allclose(Ha_v, 2.) assert numpy.allclose(Ha_v, 2.)
assert numpy.allclose(Hb_v, 2.) assert numpy.allclose(Hb_v, 2.)
def test_stack_hessian2(self):
# Test the hessian macro when the gradient itself does not depend
# on the input (but the cost does)
a = tensor.dvector('a')
b = tensor.dvector('b')
A = stack([a, b])
Ha, Hb = hessian(A.sum(), [a, b])
# Try some values
a_v = numpy.random.rand(4)
b_v = numpy.random.rand(4)
f = theano.function([a, b], [Ha, Hb])
Ha_v, Hb_v = f(a_v, b_v)
print Ha_v
print Hb_v
# The Hessian is always a matrix full of 0
assert Ha_v.shape == (4, 4)
assert Hb_v.shape == (4, 4)
assert numpy.allclose(Ha_v, 0.)
assert numpy.allclose(Hb_v, 0.)
def test_join_concatenate_one_element(self): def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join. ''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input''' also test that we remove the Join op if there is only 1 input'''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论