提交 965afb68 authored 作者: Ian Goodfellow's avatar Ian Goodfellow 提交者: Olivier Delalleau

bug fix where grad didn't obey its keep_type flag

changed default value of keep_type in accordance with decision from theano-dev mailing list added unit test to ensure that grad behaves as it is required to for pylearn2 support
上级 b4973089
......@@ -233,10 +233,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient
#########################
# TODO For Theano 0.5, change default value of `keep_wrt_type` to True
# and get rid of the `None` option (in docstring and in code).
def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
disconnected_inputs='raise', keep_wrt_type=None):
disconnected_inputs='raise', keep_wrt_type=True):
"""
:type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s.
......@@ -260,10 +258,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
returned output is of the same type. When False, if `wrt` is a one-element
list or tuple, then the returned value is a single `Variable` (and if
`wrt` is a list or tuple with at least two elements, then the returned
value is always a list -- never a tuple). This option may also be set to
None, in which case it behaves as if it was False, but a warning is also
issued when `wrt` is a one-element list or tuple, since we intend to change
the default behavior in a future Theano version.
value is always a list -- never a tuple).
This option has no effect when `wrt` is a single `Variable` (in which case
the returned value is always a single `Variable`).
......@@ -338,20 +333,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
ret = tuple(ret)
if len(ret) == 1:
if (using_list or using_tuple) and keep_wrt_type is None:
warnings.warn(
"The return type of `tensor.grad(cost, wrt)` will change "
"in the case where `wrt` is a one-element list/tuple. "
"In the future `grad(cost, wrt)` will return by default "
"an object of the same type as `wrt` (so if `wrt` is a "
"list/tuple, a list/tuple will be returned, while if it "
"is a single Variable, then a single Variable will be "
"returned). You may get rid of this warning by adding "
"'keep_wrt_type=True' (or False) when calling "
"`tensor.grad`, depending on whether you want the new "
"or old behavior.",
stacklevel=2)
if keep_wrt_type:
if keep_wrt_type and (using_list or using_tuple):
return ret
else:
return ret[0]
......
......@@ -22,6 +22,7 @@ from theano.gof.python25 import any, all, combinations
from theano.compile.mode import get_default_mode
from theano import function
from theano.tests import unittest_tools as utt
import theano.tensor as T
imported_scipy_special = False
......@@ -3662,6 +3663,25 @@ class test_grad(unittest.TestCase):
self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1)
def test_grad_keep_type(self):
"""Tests that the theano grad method returns a list if it is passed a list
and a single variable if it is passed a single variable.
pylearn2 depends on theano behaving this way but theano developers have
repeatedly changed it """
X = T.matrix()
y = X.sum()
G = T.grad(y, [X])
assert isinstance(G,list)
G = T.grad(y, X)
assert not isinstance(G,list)
def test_1None_rval(self):
"""grad: Test returning a single zero value from grad"""
o = test_grad.O()
......@@ -5199,6 +5219,7 @@ class test_size(unittest.TestCase):
assert y.size == function([], x.size)()
if __name__ == '__main__':
if 1:
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论