Commit db7d1387 authored by Pascal Lamblin

Change the name of parameter assume_continuously_differentiable of grad

Parent 06d6d1ba
...@@ -4697,7 +4697,7 @@ outer = Outer() ...@@ -4697,7 +4697,7 @@ outer = Outer()
######################### #########################
def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False, def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
assume_continuously_differentiable = False): disconnected_inputs='raise'):
""" """
:type cost: Scalar (0-dimensional) `Variable` :type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s. :type wrt: `Variable` or list of `Variable`s.
...@@ -4709,13 +4709,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False, ...@@ -4709,13 +4709,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
:param warn_type: a value of True will cause warnings to be logged for any Op that emits a :param warn_type: a value of True will cause warnings to be logged for any Op that emits a
gradient that does not match its input type. gradient that does not match its input type.
:param assume_continuously_differentiable : flag that says if grad is strict about what it returns. :type disconnected_inputs: string
If set to false it will raise an exception for any argument in :param disconnected_inputs: Defines the behaviour if some of the variables
``wrt`` for which there is no gradient either because some op does in ``wrt`` are not part of the computational graph computing ``cost``
not know how to compute the gradient with respect to that argument (or if all links are non-differentiable). The possible values are:
or the argument is not part of the computational graph. If the flag - 'ignore': considers that the gradient on these parameters is zero.
is set to true, the ``grad`` method returns zeros like the argument - 'warn': consider the gradient zero, and print a warning.
( i.e. it makes the assumption that the gradient should be 0). - 'raise': raise an exception.
:rtype: `Variable` or list of `Variable`s (depending upon `wrt`) :rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
...@@ -4758,13 +4758,24 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False, ...@@ -4758,13 +4758,24 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
wrt = [wrt] wrt = [wrt]
ret = [] ret = []
for p in wrt: for p in wrt:
if p not in gmap and not assume_continuously_differentiable: if p in gmap:
raise ValueError(("grad method was asked to compute the gradient " ret.append(gmap[p])
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"by a non-differentiable operator"), p)
else: else:
ret.append(gmap.get(p, zeros_like(p))) message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
if disconnected_inputs == 'ignore':
pass
elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=1)
elif disconnected_inputs == 'raise':
raise ValueError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p))
if len(ret) == 1: if len(ret) == 1:
return ret[0] return ret[0]
...@@ -5039,7 +5050,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -5039,7 +5050,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
g_cost = cast(g_cost, o_output.dtype) g_cost = cast(g_cost, o_output.dtype)
symbolic_grad = grad(cost, tensor_pt, g_cost, symbolic_grad = grad(cost, tensor_pt, g_cost,
assume_continuously_differentiable = True) disconnected_inputs='ignore')
#if o_output.dtype in ['float32','float64']: #if o_output.dtype in ['float32','float64']:
# assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp)) # assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
......
...@@ -3235,7 +3235,7 @@ class test_grad(unittest.TestCase): ...@@ -3235,7 +3235,7 @@ class test_grad(unittest.TestCase):
o = test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g = grad(a1.outputs[0], a1.outputs[1], g = grad(a1.outputs[0], a1.outputs[1],
assume_continuously_differentiable = True) disconnected_inputs='ignore')
self.assertTrue(g.owner.op == fill) self.assertTrue(g.owner.op == fill)
self.assertTrue(g.owner.inputs[1].data == 0) self.assertTrue(g.owner.inputs[1].data == 0)
self.assertRaises(ValueError, grad, a1.outputs[0], 'wtf') self.assertRaises(ValueError, grad, a1.outputs[0], 'wtf')
...@@ -3245,7 +3245,7 @@ class test_grad(unittest.TestCase): ...@@ -3245,7 +3245,7 @@ class test_grad(unittest.TestCase):
o = test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')], g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')],
assume_continuously_differentiable = True) disconnected_inputs='ignore')
self.assertTrue(o.gval0 is g0) self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1) self.assertTrue(o.gval1 is g1)
self.assertTrue(g2.owner.op == fill) self.assertTrue(g2.owner.op == fill)
...@@ -3255,7 +3255,7 @@ class test_grad(unittest.TestCase): ...@@ -3255,7 +3255,7 @@ class test_grad(unittest.TestCase):
"""Ensure that a zero gradient has the proper shape.""" """Ensure that a zero gradient has the proper shape."""
x = dmatrix() x = dmatrix()
f = theano.function([x], grad(dscalar(), x, f = theano.function([x], grad(dscalar(), x,
assume_continuously_differentiable= True)) disconnected_inputs='ignore'))
a = numpy.ones((3, 7)) a = numpy.ones((3, 7))
self.assertTrue((f(a) == 0).all()) # Zero gradient. self.assertTrue((f(a) == 0).all()) # Zero gradient.
self.assertTrue(a.shape == f(a).shape) # With proper shape. self.assertTrue(a.shape == f(a).shape) # With proper shape.
......
...@@ -2651,9 +2651,9 @@ def test_make_vector(): ...@@ -2651,9 +2651,9 @@ def test_make_vector():
s = mv.sum() s = mv.sum()
gb = T.grad(s, b, assume_continuously_differentiable=True) gb = T.grad(s, b, disconnected_inputs='ignore')
gi = T.grad(s, i, assume_continuously_differentiable=True) gi = T.grad(s, i, disconnected_inputs='ignore')
gd = T.grad(s, d, assume_continuously_differentiable=True) gd = T.grad(s, d, disconnected_inputs='ignore')
#print 'gb =', gb #print 'gb =', gb
#print 'gi =', gi #print 'gi =', gi
#print 'gd =', gd #print 'gd =', gd
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Please register or sign in to comment