提交 776dabd1 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2633 from carriepl/disconnected_grad_connection_pattern

Add connection_pattern
...@@ -1945,6 +1945,8 @@ class DisconnectedGrad(ViewOp): ...@@ -1945,6 +1945,8 @@ class DisconnectedGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [disconnected_type() for g_out in g_outs] return [disconnected_type() for g_out in g_outs]
def connection_pattern(self, node):
return [[False]]
disconnected_grad_ = DisconnectedGrad() disconnected_grad_ = DisconnectedGrad()
......
...@@ -719,6 +719,14 @@ class TestDisconnectedGrad(unittest.TestCase): ...@@ -719,6 +719,14 @@ class TestDisconnectedGrad(unittest.TestCase):
assert np.allclose(f(a), f2(a)) assert np.allclose(f(a), f2(a))
def test_connection_pattern(self):
T = theano.tensor
x = T.matrix('x')
y = gradient.disconnected_grad(x)
connection_pattern = y.owner.op.connection_pattern(y.owner)
assert connection_pattern == [[False]]
def test_disconnected_paths(self): def test_disconnected_paths(self):
# Test that taking gradient going through a disconnected # Test that taking gradient going through a disconnected
# path rasises an exception # path rasises an exception
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论