提交 6a023bd4 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added return_disconnected parameter to grad function

上级 4d7dc3e0
...@@ -350,7 +350,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -350,7 +350,7 @@ def Lop(f, wrt, eval_points, consider_constant=None,
def grad(cost, wrt, g_cost=None, consider_constant=None, def grad(cost, wrt, g_cost=None, consider_constant=None,
disconnected_inputs='raise', add_names=True, disconnected_inputs='raise', add_names=True,
known_grads=None): known_grads=None, return_disconnected='zero'):
""" """
:type cost: Scalar (0-dimensional) Variable. :type cost: Scalar (0-dimensional) Variable.
May optionally be None if known_grads is provided. May optionally be None if known_grads is provided.
...@@ -380,6 +380,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -380,6 +380,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
gradient on some variables but do not know the original gradient on some variables but do not know the original
cost. cost.
:type return_disconnected: string
:param return_disconnected:
'zero' : If wrt[i] is disconnected, return value i will be
wrt[i].zeros_like()
'None' : If wrt[i] is disconnected, return value i will be
None
'Disconnected' : returns variables of type DisconnectedType
:rtype: Variable or list/tuple of Variables (depending upon `wrt`) :rtype: Variable or list/tuple of Variables (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
...@@ -532,7 +540,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -532,7 +540,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
rval[i] = _float_zeros_like(wrt[i]) if return_disconnected == 'zero':
rval[i] = _float_zeros_like(wrt[i])
elif return_disconnected == 'None':
rval[i] = None
else:
assert return_disconnected == 'Disconnected'
if using_tuple: if using_tuple:
rval = tuple(rval) rval = tuple(rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论