提交 2040b54a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed a check for undefined grads in scan_module that could be triggered

by other problems
上级 70ed4221
......@@ -753,6 +753,11 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
return var_to_app_to_idx
class NullTypeGradError(TypeError):
"""
Raised when grad encounters a NullType.
"""
pass
def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None):
......@@ -1010,7 +1015,7 @@ def _populate_grad_dict(var_to_node_to_idx,
type(term)))
if isinstance(term.type, NullType):
raise TypeError("tensor.grad "
raise NullTypeGradError("tensor.grad "
"encountered a NaN. " +\
term.type.why_null)
......
......@@ -1261,11 +1261,10 @@ class Scan(PureOp):
if x in diff_inputs]
for x in consider_inps:
try:
_gmp = gradient.grad_sources_inputs(
[(y, g_y)],
[x])
gmp[x] = _gmp[x]
except TypeError:
gmp[x] = gradient.grad(cost=None,
known_grads={y: g_y},
wrt=x)
except gradient.NullTypeGradError:
# It means the gradient is undefined (which implies
# is connected)
gmp[x] = x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论