提交 ccdff864 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed the ticket about the grad method silently using 0 as the gradients wrt

to a variable when it did not know how to compute those gradients.
上级 7eafc8ae
...@@ -4717,13 +4717,29 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False): ...@@ -4717,13 +4717,29 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
list(inputs) + list(consider_constant), list(inputs) + list(consider_constant),
warn_type=warn_type) warn_type=warn_type)
# Note that it is important to use `zeros_like` when there is no gradient,
# instead of returning a scalar constant equal to zero. Otherwise we lose # Note : If p is not in gmap there can be several reasons, among which
# the guarantee that the gradient has same shape as `wrt`. # is the fact that p might not be part of the computational graph. A
if isinstance(wrt, (list, tuple)): # simple example is that for a+b for e.g. a[0] is not part of the graph,
return [gmap.get(p, zeros_like(p)) for p in wrt] # so Theano does not know how to compute TT.grad(TT.sum(a+b), a[0])
# such subtle cases can be fixed by a more careful implementation of the
# gradient, but for now Theano needs to throw an exception, and make the
# user aware that it does not know how to compute that gradient
if not isinstance(wrt, (list, tuple)):
wrt = [wrt]
ret = []
for p in wrt:
if p not in gmap:
raise ValueError(("grad method was asked to compute the graident "
"with respect to a variable that is not part of "
"the computational graph of the cost"),p)
else:
ret.append(gmap[p])
if len(ret) == 1:
return ret[0]
else: else:
return gmap.get(wrt, zeros_like(wrt)) return ret
class numeric_grad: class numeric_grad:
"""WRITEME""" """WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论