提交 c08a4ec0 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added a specific exception for disconnected inputs, so it can be

explicitly unit tested
上级 17458063
...@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None,
(or if all links are non-differentiable). The possible values are: (or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero. - 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning. - 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception. - 'raise': raise DisconnectedInputError.
:type add_names: bool :type add_names: bool
:param add_names: If True, variables generated by grad will be named :param add_names: If True, variables generated by grad will be named
...@@ -499,7 +499,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -499,7 +499,7 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2) warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise': elif disconnected_inputs == 'raise':
raise ValueError(message) raise DisconnectedInputError(message)
else: else:
raise ValueError("Invalid value for keyword " raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are " "'disconnected_inputs', valid values are "
...@@ -719,7 +719,13 @@ class NullTypeGradError(TypeError): ...@@ -719,7 +719,13 @@ class NullTypeGradError(TypeError):
""" """
Raised when grad encounters a NullType. Raised when grad encounters a NullType.
""" """
pass
class DisconnectedInputError(ValueError):
"""
Raised when grad is asked to compute the gradient
with respect to a disconnected input and
disconnected_inputs='raise'.
"""
def _populate_grad_dict(var_to_node_to_idx, def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None): grad_dict, wrt, cost_name=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论