提交 c08a4ec0 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added a specific exception for disconnected inputs, so it can be

explicitly unit tested
上级 17458063
......@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None,
(or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
- 'raise': raise DisconnectedInputError.
:type add_names: bool
:param add_names: If True, variables generated by grad will be named
......@@ -499,7 +499,7 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise':
raise ValueError(message)
raise DisconnectedInputError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
......@@ -719,7 +719,13 @@ class NullTypeGradError(TypeError):
"""
Raised when grad encounters a NullType.
"""
pass
class DisconnectedInputError(ValueError):
    """Exception signalling that ``grad`` was asked for the gradient
    with respect to an input that the cost does not depend on
    (no differentiable path connects them), while the caller passed
    ``disconnected_inputs='raise'``.

    Subclasses :class:`ValueError` so existing callers that catch
    ``ValueError`` keep working.
    """
def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论