提交 d77cd978 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add unit test for connection_pattern

上级 e3e6951e
......@@ -719,6 +719,14 @@ class TestDisconnectedGrad(unittest.TestCase):
assert np.allclose(f(a), f2(a))
def test_connection_pattern(self):
T = theano.tensor
x = T.matrix('x')
y = gradient.disconnected_grad(x)
connection_pattern = y.owner.op.connection_pattern(y.owner)
assert connection_pattern == [[False]]
def test_disconnected_paths(self):
# Test that taking gradient going through a disconnected
# path rasises an exception
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论