Commit 7e5d2077 authored by goodfeli

Merge pull request #105 from delallea/remove_keep_wrt_type

Removed keep_wrt_type parameter of tensor.grad
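
In short: `tensor.grad` now always returns the same container type as `wrt`; the flag that allowed collapsing a one-element list or tuple into a bare `Variable` is gone. A minimal sketch of the resulting contract, assuming the standard `theano.tensor` API (the variable names here are illustrative):

    import theano.tensor as T

    x = T.scalar('x')
    y = T.scalar('y')
    cost = x * y  # a 0-dimensional (scalar) cost, as grad() requires

    # A single Variable in `wrt` yields a single Variable.
    gx = T.grad(cost, x)

    # A one-element list in `wrt` yields a one-element list; before this
    # commit, keep_wrt_type=False would have collapsed it to a bare Variable.
    (gx_only,) = T.grad(cost, [x])

    # A tuple in `wrt` yields a tuple.
    gx2, gy2 = T.grad(cost, (x, y))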
@@ -234,7 +234,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
 #########################
 def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
-         disconnected_inputs='raise', keep_wrt_type=True):
+         disconnected_inputs='raise'):
     """
     :type cost: Scalar (0-dimensional) `Variable`
     :type wrt: `Variable` or list of `Variable`s.
@@ -254,14 +254,6 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
         - 'warn': consider the gradient zero, and print a warning.
         - 'raise': raise an exception.
-    :param keep_wrt_type: When True, if `wrt` is a list or tuple, then the
-    returned output is of the same type. When False, if `wrt` is a one-element
-    list or tuple, then the returned value is a single `Variable` (and if
-    `wrt` is a list or tuple with at least two elements, then the returned
-    value is always a list -- never a tuple).
-    This option has no effect when `wrt` is a single `Variable` (in which case
-    the returned value is always a single `Variable`).
     :rtype: `Variable` or list/tuple of `Variable`s (depending upon `wrt`)
     :return: symbolic expression of gradient of `cost` with respect to `wrt`.
@@ -329,16 +321,16 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
                              "'ignore', 'warn' and 'raise'.")
         ret.append(zeros_like(p))
-    if keep_wrt_type and using_tuple:
-        ret = tuple(ret)
-    if len(ret) == 1:
-        if keep_wrt_type and (using_list or using_tuple):
-            return ret
-        else:
-            return ret[0]
-    else:
-        return ret
+    if len(ret) == 1 and not (using_list or using_tuple):
+        # `wrt` was a single Variable, so we return a single Variable too.
+        return ret[0]
+    else:
+        # Ensure we preserve the original type of `wrt`.
+        if using_tuple:
+            return tuple(ret)
+        else:
+            assert using_list
+            return ret
 
 
 class numeric_grad:
@@ -614,8 +606,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
         g_cost = cast(g_cost, o_output.dtype)
         symbolic_grad = grad(cost, tensor_pt, g_cost,
-                             disconnected_inputs='ignore',
-                             keep_wrt_type=True)
+                             disconnected_inputs='ignore')
         #if o_output.dtype in ['float32','float64']:
         #    assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
...
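
Note that the `verify_grad` hunk above simply drops the keyword; its behavior is unchanged, because `keep_wrt_type=True` was already the default that this commit hard-codes. A usage sketch, assuming `verify_grad` is importable from the module this diff touches (the exact import path is an assumption):

    import numpy
    from theano.gradient import verify_grad  # import path assumed

    rng = numpy.random.RandomState(42)

    # Compare the symbolic gradient of a simple function against a numeric
    # estimate at a random test point; internally this now calls
    # grad(..., disconnected_inputs='ignore') without the removed keyword.
    verify_grad(lambda a: (a ** 2).sum(), [rng.rand(3, 4)], rng=rng)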