Commit 1282743e authored by Olivier Delalleau


When a gradient is zero because the variable has no influence on the result, return a variable with the same shape instead of a 0-d scalar, so that we can rely on the gradient's shape matching the variable's.
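The rationale can be illustrated with plain NumPy (a hypothetical stand-in; Theano's symbolic variables are not used here): a 0-d scalar zero discards the shape of the variable it stands in for, while `zeros_like` preserves both shape and dtype.

```python
import numpy as np

# A variable with respect to which the cost has zero gradient.
wrt = np.empty((3, 4), dtype='float32')

scalar_grad = np.float32(0)        # old behaviour: 0-d scalar constant
shaped_grad = np.zeros_like(wrt)   # new behaviour: same shape and dtype as `wrt`

print(scalar_grad.shape)  # ()
print(shaped_grad.shape)  # (3, 4)
print(shaped_grad.dtype)  # float32
```

Downstream code that assumes `grad(cost, x)` has the same shape as `x` (e.g. for an elementwise update `x - lr * g`) breaks silently under the old scalar behaviour only when broadcasting happens to be invalid; the new behaviour makes the shape guarantee unconditional.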
Parent 7c7c6918
@@ -3452,8 +3452,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
     :return: symbolic expression of gradient of `cost` with respect to `wrt`.
     If `wrt` is a list, then return a list containing the gradient of `cost` wrt
     each element of the list. If an element of `wrt` is not differentiable
-    with respect to the output, then a `TensorConstant` with an appropriate
-    kind of zero is returned.
+    with respect to the output, then a zero variable is returned.
     This function is a wrapper around the more general function
     `theano.gradient.grad_sources_inputs`.
@@ -3473,21 +3472,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
     gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant,
             warn_type=warn_type)
-    def zero(p):
-        return TensorConstant(
-                TensorType(dtype=p.type.dtype, broadcastable=[]),
-                theano._asarray(0, dtype=p.type.dtype))
-    #try:
-        #it = iter(wrt)
-    #except:
-        #it = None
-    #if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
+    # Note that it is important to use `zeros_like` when there is no gradient,
+    # instead of returning a scalar constant equal to zero. Otherwise we lose
+    # the guarantee that the gradient has the same shape as `wrt`.
     if isinstance(wrt, (list, tuple)):
-        return [gmap.get(p, zero(p)) for p in wrt]
+        return [gmap.get(p, zeros_like(p)) for p in wrt]
     else:
-        return gmap.get(wrt, zero(wrt))
+        return gmap.get(wrt, zeros_like(wrt))

 class numeric_grad:
     """WRITEME"""