Commit f2c73f6f, authored by nouiz

Merge pull request #1084 from goodfeli/test_grad_2

Fix disconnected input bug
...@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None,
(or if all links are non-differentiable). The possible values are: (or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero. - 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning. - 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception. - 'raise': raise DisconnectedInputError.
:type add_names: bool :type add_names: bool
:param add_names: If True, variables generated by grad will be named :param add_names: If True, variables generated by grad will be named
...@@ -482,28 +482,31 @@ def grad(cost, wrt, consider_constant=None, ...@@ -482,28 +482,31 @@ def grad(cost, wrt, consider_constant=None,
grad_dict[var] = g_var grad_dict[var] = g_var
def handle_disconnected(var):
    """Apply the ``disconnected_inputs`` policy to a variable that
    does not influence the cost.

    Depending on the (closed-over) ``disconnected_inputs`` flag:
    'ignore' does nothing, 'warn' emits a warning, 'raise' raises
    DisconnectedInputError; any other value is itself a usage error.
    """
    # The message names the offending variable so the user can tell
    # which input is disconnected from the cost graph.
    message = ("grad method was asked to compute the gradient "
               "with respect to a variable that is not part of "
               "the computational graph of the cost, or is used "
               "only by a non-differentiable operator: %s" % var)
    if disconnected_inputs == 'ignore':
        pass
    elif disconnected_inputs == 'warn':
        # stacklevel=2 attributes the warning to the caller of grad
        warnings.warn(message, stacklevel=2)
    elif disconnected_inputs == 'raise':
        raise DisconnectedInputError(message)
    else:
        raise ValueError("Invalid value for keyword "
                         "'disconnected_inputs', valid values are "
                         "'ignore', 'warn' and 'raise'.")

# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_node_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
    if elem not in var_to_node_to_idx and elem is not cost \
            and elem not in grad_dict:
        handle_disconnected(elem)
        grad_dict[elem] = DisconnectedType()()
cost_name = None cost_name = None
...@@ -523,6 +526,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -523,6 +526,7 @@ def grad(cost, wrt, consider_constant=None,
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i])
if return_disconnected == 'zero': if return_disconnected == 'zero':
rval[i] = _float_zeros_like(wrt[i]) rval[i] = _float_zeros_like(wrt[i])
elif return_disconnected == 'None': elif return_disconnected == 'None':
class NullTypeGradError(TypeError):
    """
    Raised when grad encounters a NullType.
    """
class DisconnectedInputError(ValueError):
    """
    Error raised by grad when disconnected_inputs='raise' and the
    gradient is requested with respect to an input variable that does
    not take part in computing the cost.
    """
def _populate_grad_dict(var_to_node_to_idx, def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None): grad_dict, wrt, cost_name=None):
......
...@@ -522,6 +522,8 @@ def test_undefined_cost_grad(): ...@@ -522,6 +522,8 @@ def test_undefined_cost_grad():
# Tests that if we say the cost is not differentiable via the # Tests that if we say the cost is not differentiable via the
# known_grads mechanism, it is treated as such by the rest of the # known_grads mechanism, it is treated as such by the rest of the
# system. # system.
# This is so that Ops that are built around minigraphs like OpFromGraph
# and scan can implement Op.grad by passing ograds to known_grads
x = theano.tensor.iscalar() x = theano.tensor.iscalar()
y = theano.tensor.iscalar() y = theano.tensor.iscalar()
...@@ -533,6 +535,24 @@ def test_undefined_cost_grad(): ...@@ -533,6 +535,24 @@ def test_undefined_cost_grad():
return return
raise AssertionError("An undefined gradient has been ignored.") raise AssertionError("An undefined gradient has been ignored.")
def test_disconnected_cost_grad():
    # Tests that if we say the cost is disconnected via the
    # known_grads mechanism, it is treated as such by the rest of the
    # system.
    # This is so that Ops that are built around minigraphs like OpFromGraph
    # and scan can implement Op.grad by passing ograds to known_grads
    x = theano.tensor.iscalar()
    y = theano.tensor.iscalar()
    cost = x + y
    # an integer cost is normally non-differentiable; known_grads
    # overrides that, so the only signal left is the disconnection
    assert cost.dtype in theano.tensor.discrete_dtypes
    try:
        # the returned gradients are irrelevant; we only care that the
        # disconnection is detected and raised
        theano.tensor.grad(
            cost, [x, y],
            known_grads={cost: theano.gradient.DisconnectedType()()},
            disconnected_inputs='raise')
    except theano.gradient.DisconnectedInputError:
        return
    raise AssertionError("A disconnected gradient has been ignored.")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment