提交 db7d1387 作者: Pascal Lamblin

Change the name of parameter assume_continuously_differentiable of grad

上级 06d6d1ba
......@@ -4697,7 +4697,7 @@ outer = Outer()
#########################
def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
assume_continuously_differentiable = False):
disconnected_inputs='raise'):
"""
:type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s.
......@@ -4709,13 +4709,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
:param warn_type: a value of True will cause warnings to be logged for any Op that emits a
gradient that does not match its input type.
:param assume_continuously_differentiable : flag that says if grad is strict about what it returns.
If set to false it will raise an exception for any argument in
``wrt`` for which there is no gradient either because some op does
not know how to compute the gradient with respect to that argument
or the argument is not part of the computational graph. If the flag
is set to true, the ``grad`` method returns zeros like the argument
( i.e. it makes the assumption that the gradient should be 0).
:type disconnected_inputs: string
:param disconnected_inputs: Defines the behaviour if some of the variables
in ``wrt`` are not part of the computational graph computing ``cost``
(or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
:rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
......@@ -4758,13 +4758,24 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
wrt = [wrt]
ret = []
for p in wrt:
if p not in gmap and not assume_continuously_differentiable:
raise ValueError(("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"by a non-differentiable operator"), p)
if p in gmap:
ret.append(gmap[p])
else:
ret.append(gmap.get(p, zeros_like(p)))
message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
if disconnected_inputs == 'ignore':
pass
elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=1)
elif disconnected_inputs == 'raise':
raise ValueError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p))
if len(ret) == 1:
return ret[0]
......@@ -5039,7 +5050,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
g_cost = cast(g_cost, o_output.dtype)
symbolic_grad = grad(cost, tensor_pt, g_cost,
assume_continuously_differentiable = True)
disconnected_inputs='ignore')
#if o_output.dtype in ['float32','float64']:
# assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
......
......@@ -3235,7 +3235,7 @@ class test_grad(unittest.TestCase):
o = test_grad.O()
a1 = o.make_node()
g = grad(a1.outputs[0], a1.outputs[1],
assume_continuously_differentiable = True)
disconnected_inputs='ignore')
self.assertTrue(g.owner.op == fill)
self.assertTrue(g.owner.inputs[1].data == 0)
self.assertRaises(ValueError, grad, a1.outputs[0], 'wtf')
......@@ -3245,7 +3245,7 @@ class test_grad(unittest.TestCase):
o = test_grad.O()
a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')],
assume_continuously_differentiable = True)
disconnected_inputs='ignore')
self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1)
self.assertTrue(g2.owner.op == fill)
......@@ -3255,7 +3255,7 @@ class test_grad(unittest.TestCase):
"""Ensure that a zero gradient has the proper shape."""
x = dmatrix()
f = theano.function([x], grad(dscalar(), x,
assume_continuously_differentiable= True))
disconnected_inputs='ignore'))
a = numpy.ones((3, 7))
self.assertTrue((f(a) == 0).all()) # Zero gradient.
self.assertTrue(a.shape == f(a).shape) # With proper shape.
......
......@@ -2651,9 +2651,9 @@ def test_make_vector():
s = mv.sum()
gb = T.grad(s, b, assume_continuously_differentiable=True)
gi = T.grad(s, i, assume_continuously_differentiable=True)
gd = T.grad(s, d, assume_continuously_differentiable=True)
gb = T.grad(s, b, disconnected_inputs='ignore')
gi = T.grad(s, i, disconnected_inputs='ignore')
gd = T.grad(s, d, disconnected_inputs='ignore')
#print 'gb =', gb
#print 'gi =', gi
#print 'gd =', gd
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论