提交 965afb68 authored 作者: Ian Goodfellow's avatar Ian Goodfellow 提交者: Olivier Delalleau

bug fix where grad didn't obey its keep_type flag

changed default value of keep_type in accordance with decision from theano-dev mailing list added unit test to ensure that grad behaves as it is required to for pylearn2 support
上级 b4973089
...@@ -233,10 +233,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -233,10 +233,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient # Gradient
######################### #########################
# TODO For Theano 0.5, change default value of `keep_wrt_type` to True
# and get rid of the `None` option (in docstring and in code).
def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
disconnected_inputs='raise', keep_wrt_type=None): disconnected_inputs='raise', keep_wrt_type=True):
""" """
:type cost: Scalar (0-dimensional) `Variable` :type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s. :type wrt: `Variable` or list of `Variable`s.
...@@ -260,10 +258,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -260,10 +258,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
returned output is of the same type. When False, if `wrt` is a one-element returned output is of the same type. When False, if `wrt` is a one-element
list or tuple, then the returned value is a single `Variable` (and if list or tuple, then the returned value is a single `Variable` (and if
`wrt` is a list or tuple with at least two elements, then the returned `wrt` is a list or tuple with at least two elements, then the returned
value is always a list -- never a tuple). This option may also be set to value is always a list -- never a tuple).
None, in which case it behaves as if it was False, but a warning is also
issued when `wrt` is a one-element list or tuple, since we intend to change
the default behavior in a future Theano version.
This option has no effect when `wrt` is a single `Variable` (in which case This option has no effect when `wrt` is a single `Variable` (in which case
the returned value is always a single `Variable`). the returned value is always a single `Variable`).
...@@ -338,20 +333,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -338,20 +333,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
ret = tuple(ret) ret = tuple(ret)
if len(ret) == 1: if len(ret) == 1:
if (using_list or using_tuple) and keep_wrt_type is None: if keep_wrt_type and (using_list or using_tuple):
warnings.warn(
"The return type of `tensor.grad(cost, wrt)` will change "
"in the case where `wrt` is a one-element list/tuple. "
"In the future `grad(cost, wrt)` will return by default "
"an object of the same type as `wrt` (so if `wrt` is a "
"list/tuple, a list/tuple will be returned, while if it "
"is a single Variable, then a single Variable will be "
"returned). You may get rid of this warning by adding "
"'keep_wrt_type=True' (or False) when calling "
"`tensor.grad`, depending on whether you want the new "
"or old behavior.",
stacklevel=2)
if keep_wrt_type:
return ret return ret
else: else:
return ret[0] return ret[0]
......
...@@ -22,6 +22,7 @@ from theano.gof.python25 import any, all, combinations ...@@ -22,6 +22,7 @@ from theano.gof.python25 import any, all, combinations
from theano.compile.mode import get_default_mode from theano.compile.mode import get_default_mode
from theano import function from theano import function
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import theano.tensor as T
imported_scipy_special = False imported_scipy_special = False
...@@ -3662,6 +3663,25 @@ class test_grad(unittest.TestCase): ...@@ -3662,6 +3663,25 @@ class test_grad(unittest.TestCase):
self.assertTrue(o.gval0 is g0) self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1) self.assertTrue(o.gval1 is g1)
def test_grad_keep_type(self):
"""Tests that the theano grad method returns a list if it is passed a list
and a single variable if it is passed a single variable.
pylearn2 depends on theano behaving this way but theano developers have
repeatedly changed it """
X = T.matrix()
y = X.sum()
G = T.grad(y, [X])
assert isinstance(G,list)
G = T.grad(y, X)
assert not isinstance(G,list)
def test_1None_rval(self): def test_1None_rval(self):
"""grad: Test returning a single zero value from grad""" """grad: Test returning a single zero value from grad"""
o = test_grad.O() o = test_grad.O()
...@@ -5199,6 +5219,7 @@ class test_size(unittest.TestCase): ...@@ -5199,6 +5219,7 @@ class test_size(unittest.TestCase):
assert y.size == function([], x.size)() assert y.size == function([], x.size)()
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论