Commit f2c73f6f authored by nouiz

Merge pull request #1084 from goodfeli/test_grad_2

Fix disconnected input bug
......@@ -365,7 +365,7 @@ def grad(cost, wrt, consider_constant=None,
(or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
- 'raise': raise DisconnectedInputError.
:type add_names: bool
:param add_names: If True, variables generated by grad will be named
......@@ -482,28 +482,31 @@ def grad(cost, wrt, consider_constant=None,
grad_dict[var] = g_var
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_node_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in var_to_node_to_idx and elem is not cost \
and elem not in grad_dict:
def handle_disconnected(var):
    """React to a disconnected input according to `disconnected_inputs`.

    Depending on the closed-over `disconnected_inputs` flag (a
    parameter of `grad` — presumably one of 'ignore', 'warn',
    'raise'), either do nothing, emit a warning, or raise
    DisconnectedInputError for the disconnected variable `var`.
    Any other flag value is itself an error.
    """
    message = ("grad method was asked to compute the gradient "
               "with respect to a variable that is not part of "
               "the computational graph of the cost, or is used "
               "only by a non-differentiable operator: %s" % var)
    # Handle the strict mode first; the remaining branches are the
    # two lenient modes plus the invalid-flag fallback.
    if disconnected_inputs == 'raise':
        raise DisconnectedInputError(message)
    if disconnected_inputs == 'warn':
        # stacklevel=2 attributes the warning to grad's caller.
        warnings.warn(message, stacklevel=2)
    elif disconnected_inputs != 'ignore':
        raise ValueError("Invalid value for keyword "
                         "'disconnected_inputs', valid values are "
                         "'ignore', 'warn' and 'raise'.")
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_node_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in var_to_node_to_idx and elem is not cost \
and elem not in grad_dict:
handle_disconnected(elem)
grad_dict[elem] = DisconnectedType()()
cost_name = None
......@@ -523,6 +526,7 @@ def grad(cost, wrt, consider_constant=None,
for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i])
if return_disconnected == 'zero':
rval[i] = _float_zeros_like(wrt[i])
elif return_disconnected == 'None':
......@@ -719,7 +723,13 @@ class NullTypeGradError(TypeError):
"""
Raised when grad encounters a NullType.
"""
pass
class DisconnectedInputError(ValueError):
    """Error for a gradient request on a disconnected input.

    Signals that ``grad`` was asked for the gradient with respect to
    a variable that does not influence the cost, while the caller
    passed ``disconnected_inputs='raise'``.
    """
    pass
def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None):
......
......@@ -522,6 +522,8 @@ def test_undefined_cost_grad():
# Tests that if we say the cost is not differentiable via the
# known_grads mechanism, it is treated as such by the rest of the
# system.
# This is so that Ops that are built around minigraphs like OpFromGraph
# and scan can implement Op.grad by passing ograds to known_grads
x = theano.tensor.iscalar()
y = theano.tensor.iscalar()
......@@ -533,6 +535,24 @@ def test_undefined_cost_grad():
return
raise AssertionError("An undefined gradient has been ignored.")
def test_disconnected_cost_grad():
    """Test that a cost declared disconnected via the known_grads
    mechanism is treated as disconnected by the rest of the system.

    This matters for Ops built around minigraphs, like OpFromGraph
    and scan, which implement Op.grad by passing ograds through
    known_grads.

    Raises AssertionError if the disconnection is silently ignored.
    """
    x = theano.tensor.iscalar()
    y = theano.tensor.iscalar()
    cost = x + y
    assert cost.dtype in theano.tensor.discrete_dtypes
    try:
        # The return value is irrelevant: only the raising (or not)
        # of DisconnectedInputError is under test.
        theano.tensor.grad(cost, [x, y],
                           known_grads={cost: gradient.DisconnectedType()()},
                           disconnected_inputs='raise')
    except theano.gradient.DisconnectedInputError:
        return
    raise AssertionError("A disconnected gradient has been ignored.")
if __name__ == '__main__':
unittest.main()
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment