Commit 7e2f4871 authored by Ian Goodfellow

added a unit test of the disconnected gradient tracking system

上级 c08a4ec0
...@@ -522,6 +522,8 @@ def test_undefined_cost_grad(): ...@@ -522,6 +522,8 @@ def test_undefined_cost_grad():
    # Tests that if we say the cost is not differentiable via the
    # known_grads mechanism, it is treated as such by the rest of the
    # system.
# This is so that Ops that are built around minigraphs like OpFromGraph
# and scan can implement Op.grad by passing ograds to known_grads
    x = theano.tensor.iscalar()
    y = theano.tensor.iscalar()
...@@ -533,6 +535,24 @@ def test_undefined_cost_grad(): ...@@ -533,6 +535,24 @@ def test_undefined_cost_grad():
        return
    raise AssertionError("An undefined gradient has been ignored.")
def test_disconnected_cost_grad():
    # Check that a cost marked as disconnected through the known_grads
    # mechanism is propagated as disconnected by the rest of the
    # gradient machinery (grad must raise DisconnectedInputError when
    # disconnected_inputs='raise').
    # This allows Ops built around minigraphs, such as OpFromGraph and
    # scan, to implement Op.grad by passing ograds to known_grads.
    x = theano.tensor.iscalar()
    y = theano.tensor.iscalar()
    cost = x + y
    assert cost.dtype in theano.tensor.discrete_dtypes
    try:
        theano.tensor.grad(
            cost, [x, y],
            known_grads={cost: gradient.DisconnectedType()()},
            disconnected_inputs='raise')
    except theano.gradient.DisconnectedInputError:
        # Expected: the disconnected cost gradient was detected.
        return
    raise AssertionError("A disconnected gradient has been ignored.")
# Run the tests in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Finish editing this message first!
Register or sign in to comment