Commit 90cba7f8, authored by Razvan Pascanu

renamed flag from strict to assume_continuously_differentiable

Parent e42d146d
...
@@ -4677,7 +4677,7 @@ outer = Outer()
 #########################
 def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
-         strict = True):
+         assume_continuously_differentiable = False):
     """
     :type cost: Scalar (0-dimensional) `Variable`
     :type wrt: `Variable` or list of `Variable`s.
...
@@ -4689,12 +4689,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
     :param warn_type: a value of True will cause warnings to be logged for any Op that emits a
         gradient that does not match its input type.
-    :param strict: flag that says if grad is strict about what it returns.
-        If set to true it will raise an exception for any argument in
+    :param assume_continuously_differentiable : flag that says if grad is strict about what it returns.
+        If set to false it will raise an exception for any argument in
         ``wrt`` for which there is no gradient either because some op does
         not know how to compute the gradient with respect to that argument
         or the argument is not part of the computational graph. If the flag
-        is set to false, the ``grad`` method returns zeros like the argument
+        is set to true, the ``grad`` method returns zeros like the argument
         ( i.e. it makes the assumption that the gradient should be 0).

     :rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
...
@@ -4738,7 +4738,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
         wrt = [wrt]
     ret = []
     for p in wrt:
-        if p not in gmap and strict:
+        if p not in gmap and not assume_continuously_differentiable:
             raise ValueError(("grad method was asked to compute the graident "
                               "with respect to a variable that is not part of "
                               "the computational graph of the cost or is used "
...
@@ -5018,7 +5018,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
         if cast_to_output_type:
             g_cost = cast(g_cost, o_output.dtype)
-        symbolic_grad = grad(cost, tensor_pt, g_cost, strict = False)
+        symbolic_grad = grad(cost, tensor_pt, g_cost,
+                             assume_continuously_differentiable = True)
         #if o_output.dtype in ['float32','float64']:
         #    assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
...
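For reference, a minimal sketch (not part of this commit) of the stricter default path described in the docstring above; the variables below are hypothetical and the snippet assumes the renamed keyword:

from theano.tensor import dmatrix, dscalar, grad

x = dmatrix('x')        # x is never used to build the cost below
cost = dscalar('cost')  # a cost that does not depend on x

try:
    # With the default assume_continuously_differentiable=False, asking for
    # the gradient with respect to a variable that is not part of the cost's
    # computational graph raises ValueError (the `p not in gmap` branch above).
    grad(cost, x)
except ValueError:
    print('grad refused: no gradient path from cost to x')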
...
@@ -3234,7 +3234,8 @@ class test_grad(unittest.TestCase):
         """grad: Test returning a single zero value from grad"""
         o = test_grad.O()
         a1 = o.make_node()
-        g = grad(a1.outputs[0], a1.outputs[1], strict = False)
+        g = grad(a1.outputs[0], a1.outputs[1],
+                 assume_continuously_differentiable = True)
         self.assertTrue(g.owner.op == fill)
         self.assertTrue(g.owner.inputs[1].data == 0)
         try:
...
@@ -3248,7 +3249,7 @@ class test_grad(unittest.TestCase):
         o = test_grad.O()
         a1 = o.make_node()
         g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')],
-                        strict = False)
+                        assume_continuously_differentiable = True)
         self.assertTrue(o.gval0 is g0)
         self.assertTrue(o.gval1 is g1)
         self.assertTrue(g2.owner.op == fill)
...
@@ -3257,7 +3258,8 @@ class test_grad(unittest.TestCase):
     def test_zero_gradient_shape(self):
         """Ensure that a zero gradient has the proper shape."""
         x = dmatrix()
-        f = theano.function([x], grad(dscalar(), x, strict= False))
+        f = theano.function([x], grad(dscalar(), x,
+                                      assume_continuously_differentiable= True))
         a = numpy.ones((3, 7))
         self.assertTrue((f(a) == 0).all())  # Zero gradient.
         self.assertTrue(a.shape == f(a).shape)  # With proper shape.
...
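A companion sketch (again not part of the commit) of the relaxed behaviour exercised by test_zero_gradient_shape above: with the flag set to True, grad assumes the missing gradient is zero and returns a zeros-like expression with the shape of the argument.

import numpy
import theano
from theano.tensor import dmatrix, dscalar, grad

x = dmatrix()
# The cost does not depend on x, so there is no gradient path;
# assume_continuously_differentiable=True makes grad return zeros
# instead of raising.
g = grad(dscalar(), x, assume_continuously_differentiable=True)
f = theano.function([x], g)

a = numpy.ones((3, 7))
print(f(a))         # a (3, 7) array of zeros
print(f(a).shape)   # matches a.shape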