提交 5fa90044 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Catch exception and forward NullType to input gradients, instead of failing

上级 d8b3d2dc
......@@ -1476,16 +1476,28 @@ class Scan(PureOp):
if (x in diff_inputs) and
(connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=wrt,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads))
gmp = OrderedDict()
for x in wrt:
try:
gmp[x] = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=x,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
except gradient.NullTypeGradError:
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that
# particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply forward the Null gradient.
gmp[x] = x
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论