提交 0b8b021c authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Added new option 'keep_wrt_type' to tensor.grad

This allows one to decide which behavior is desired (future new behavior or old one) and get rid of the warning at the same time.
上级 4599b25b
...@@ -233,8 +233,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -233,8 +233,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient # Gradient
######################### #########################
# TODO For Theano 0.5, change default value of `keep_wrt_type` to True
# and get rid of the `None` option (in docstring and in code).
def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
disconnected_inputs='raise'): disconnected_inputs='raise', keep_wrt_type=None):
""" """
:type cost: Scalar (0-dimensional) `Variable` :type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s. :type wrt: `Variable` or list of `Variable`s.
...@@ -254,6 +256,17 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -254,6 +256,17 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
- 'warn': consider the gradient zero, and print a warning. - 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception. - 'raise': raise an exception.
:param keep_wrt_type: When True, if `wrt` is a list or tuple, then the
returned output is of the same type. When False, if `wrt` is a one-element
list or tuple, then the returned value is a single `Variable` (and if
`wrt` is a list or tuple with at least two elements, then the returned
value is always a list -- never a tuple). This option may also be set to
None, in which case it behaves as if it was False, but a warning is also
issued when `wrt` is a one-element list or tuple, since we intend to change
the default behavior in a future Theano version.
This option has no effect when `wrt` is a single `Variable` (in which case
the returned value is always a single `Variable`).
:rtype: `Variable` or list/tuple of `Variable`s (depending upon `wrt`) :rtype: `Variable` or list/tuple of `Variable`s (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
...@@ -321,28 +334,31 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -321,28 +334,31 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p)) ret.append(zeros_like(p))
if keep_wrt_type and using_tuple:
ret = tuple(ret)
if len(ret) == 1: if len(ret) == 1:
if using_list or using_tuple: if (using_list or using_tuple) and keep_wrt_type is None:
warnings.warn(("The return type of tensor.grad will change in this " warnings.warn(
"case. In the future grad(cost, wrt) will return an " "The return type of `tensor.grad(cost, wrt)` will change "
"object of the same type as wrt. So if wrt is a " "in the case where `wrt` is a one-element list/tuple. "
"list/tuple, list/tuple will be returned. Idem for " "In the future `grad(cost, wrt)` will return by default "
"TensorVariable."), "an object of the same type as `wrt` (so if `wrt` is a "
stacklevel=2) "list/tuple, a list/tuple will be returned, while if it "
# TODO: when we release Theano 0.5, uncomment the following lines "is a single Variable, then a single Variable will be "
# and remove the warning. Don't forget the line in the currently "returned). You may get rid of this warning by adding "
# enabled else. "'keep_wrt_type=True' (or False) when calling "
#if using_list: "`tensor.grad`, depending on whether you want the new "
# return ret "or old behavior.",
#elif using_tuple: stacklevel=2)
# return tuple(ret) if keep_wrt_type:
#else: return ret
return ret[0] else:
return ret[0]
else: else:
#if using_tuple:
# return tuple(ret)
return ret return ret
class numeric_grad: class numeric_grad:
"""WRITEME""" """WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论