提交 0b8b021c authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Added new option 'keep_wrt_type' to tensor.grad

This allows one to decide which behavior is desired (future new behavior or old one) and get rid of the warning at the same time.
上级 4599b25b
......@@ -233,8 +233,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient
#########################
# TODO For Theano 0.5, change default value of `keep_wrt_type` to True
# and get rid of the `None` option (in docstring and in code).
def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
disconnected_inputs='raise'):
disconnected_inputs='raise', keep_wrt_type=None):
"""
:type cost: Scalar (0-dimensional) `Variable`
:type wrt: `Variable` or list of `Variable`s.
......@@ -254,6 +256,17 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
:param keep_wrt_type: When True, if `wrt` is a list or tuple, then the
returned output is of the same type. When False, if `wrt` is a one-element
list or tuple, then the returned value is a single `Variable` (and if
`wrt` is a list or tuple with at least two elements, then the returned
value is always a list -- never a tuple). This option may also be set to
None, in which case it behaves as if it was False, but a warning is also
issued when `wrt` is a one-element list or tuple, since we intend to change
the default behavior in a future Theano version.
This option has no effect when `wrt` is a single `Variable` (in which case
the returned value is always a single `Variable`).
:rtype: `Variable` or list/tuple of `Variable`s (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`.
......@@ -321,28 +334,31 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p))
if keep_wrt_type and using_tuple:
ret = tuple(ret)
if len(ret) == 1:
if using_list or using_tuple:
warnings.warn(("The return type of tensor.grad will change in this "
"case. In the future grad(cost, wrt) will return an "
"object of the same type as wrt. So if wrt is a "
"list/tuple, list/tuple will be returned. Idem for "
"TensorVariable."),
stacklevel=2)
# TODO: when we release Theano 0.5, uncomment the following lines
# and remove the warning. Don't forget the line in the currently
# enabled else.
#if using_list:
# return ret
#elif using_tuple:
# return tuple(ret)
#else:
return ret[0]
if (using_list or using_tuple) and keep_wrt_type is None:
warnings.warn(
"The return type of `tensor.grad(cost, wrt)` will change "
"in the case where `wrt` is a one-element list/tuple. "
"In the future `grad(cost, wrt)` will return by default "
"an object of the same type as `wrt` (so if `wrt` is a "
"list/tuple, a list/tuple will be returned, while if it "
"is a single Variable, then a single Variable will be "
"returned). You may get rid of this warning by adding "
"'keep_wrt_type=True' (or False) when calling "
"`tensor.grad`, depending on whether you want the new "
"or old behavior.",
stacklevel=2)
if keep_wrt_type:
return ret
else:
return ret[0]
else:
#if using_tuple:
# return tuple(ret)
return ret
class numeric_grad:
"""WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论