提交 2040b54a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

Fixed a check for undefined grads in scan_module that could be triggered by other, unrelated problems: it now catches the dedicated NullTypeGradError instead of any TypeError.
上级 70ed4221
...@@ -753,6 +753,11 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant): ...@@ -753,6 +753,11 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
return var_to_app_to_idx return var_to_app_to_idx
class NullTypeGradError(TypeError):
    """Raised when grad encounters a NullType.

    This is a subclass of TypeError, so handlers written as
    ``except TypeError`` still catch it, while code that only wants to
    react to an undefined gradient can catch this class specifically
    and let other TypeErrors propagate.
    """
def _populate_grad_dict(var_to_node_to_idx, def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None): grad_dict, wrt, cost_name=None):
...@@ -1010,7 +1015,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -1010,7 +1015,7 @@ def _populate_grad_dict(var_to_node_to_idx,
type(term))) type(term)))
if isinstance(term.type, NullType): if isinstance(term.type, NullType):
raise TypeError("tensor.grad " raise NullTypeGradError("tensor.grad "
"encountered a NaN. " +\ "encountered a NaN. " +\
term.type.why_null) term.type.why_null)
......
...@@ -1261,11 +1261,10 @@ class Scan(PureOp): ...@@ -1261,11 +1261,10 @@ class Scan(PureOp):
if x in diff_inputs] if x in diff_inputs]
for x in consider_inps: for x in consider_inps:
try: try:
_gmp = gradient.grad_sources_inputs( gmp[x] = gradient.grad(cost=None,
[(y, g_y)], known_grads={y: g_y},
[x]) wrt=x)
gmp[x] = _gmp[x] except gradient.NullTypeGradError:
except TypeError:
# It means the gradient is undefined (which implies # It means the gradient is undefined (which implies
# is connected) # is connected)
gmp[x] = x gmp[x] = x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论