提交 e42d146d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a flag to the grad method that decides how this method behaves ( i.e.

throws an exception or silently returns a bunch of 0s)
上级 66632981
......@@ -4676,7 +4676,8 @@ outer = Outer()
# Gradient
#########################
def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
strict = True):
"""
:type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s.
......@@ -4688,6 +4689,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
:param warn_type: a value of True will cause warnings to be logged for any Op that emits a
gradient that does not match its input type.
:param strict: flag that says if grad is strict about what it returns.
If set to true it will raise an exception for any argument in
``wrt`` for which there is no gradient either because some op does
not know how to compute the gradient with respect to that argument
or the argument is not part of the computational graph. If the flag
is set to false, the ``grad`` method returns zeros like the argument
( i.e. it makes the assumption that the gradient should be 0).
:rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`.
......@@ -4729,12 +4738,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
wrt = [wrt]
ret = []
for p in wrt:
if p not in gmap:
if p not in gmap and strict:
raise ValueError(("grad method was asked to compute the graident "
"with respect to a variable that is not part of "
"the computational graph of the cost"),p)
"the computational graph of the cost or is used "
"by a non-differentiable operator "),p)
else:
ret.append(gmap[p])
ret.append(gmap.get(p, zeros_like(p)))
if len(ret) == 1:
return ret[0]
......@@ -5008,7 +5018,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
if cast_to_output_type:
g_cost = cast(g_cost, o_output.dtype)
symbolic_grad = grad(cost, tensor_pt, g_cost)
symbolic_grad = grad(cost, tensor_pt, g_cost, strict = False)
#if o_output.dtype in ['float32','float64']:
# assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
......
......@@ -3234,7 +3234,7 @@ class test_grad(unittest.TestCase):
"""grad: Test returning a single zero value from grad"""
o = test_grad.O()
a1 = o.make_node()
g = grad(a1.outputs[0], a1.outputs[1])
g = grad(a1.outputs[0], a1.outputs[1], strict = False)
self.assertTrue(g.owner.op == fill)
self.assertTrue(g.owner.inputs[1].data == 0)
try:
......@@ -3247,7 +3247,8 @@ class test_grad(unittest.TestCase):
"""grad: Test returning some zero value from grad"""
o = test_grad.O()
a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')])
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')],
strict = False)
self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1)
self.assertTrue(g2.owner.op == fill)
......@@ -3256,7 +3257,7 @@ class test_grad(unittest.TestCase):
def test_zero_gradient_shape(self):
"""Ensure that a zero gradient has the proper shape."""
x = dmatrix()
f = theano.function([x], grad(dscalar(), x))
f = theano.function([x], grad(dscalar(), x, strict= False))
a = numpy.ones((3, 7))
self.assertTrue((f(a) == 0).all()) # Zero gradient.
self.assertTrue(a.shape == f(a).shape) # With proper shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论