提交 5fa90044 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Catch exception and forward NullType to input gradients, instead of failing

上级 d8b3d2dc
...@@ -1476,16 +1476,28 @@ class Scan(PureOp): ...@@ -1476,16 +1476,28 @@ class Scan(PureOp):
if (x in diff_inputs) and if (x in diff_inputs) and
(connection_pattern[ (connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])] get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad( gmp = OrderedDict()
cost=None,
known_grads={y: g_y}, for x in wrt:
wrt=wrt, try:
consider_constant=wrt, gmp[x] = gradient.grad(
disconnected_inputs='ignore', cost=None,
return_disconnected='None') known_grads={y: g_y},
gmp = dict(zip(wrt, grads)) wrt=x,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
except gradient.NullTypeGradError:
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that
# particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply forward the Null gradient.
gmp[x] = x
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
dC_dinps_t = [None for inp in diff_inputs] dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = [] dC_dXts = []
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论