提交 2377786e authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add test case for connection_pattern

上级 75f3fd4e
......@@ -838,6 +838,26 @@ class T_Scan(unittest.TestCase):
n_steps=2)
tensor.grad(a[-1], a0)
def test_connection_pattern2():
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
# multiple output taps) output
x = tensor.matrix()
seq = tensor.vector()
def inner_fct(seq, state_old, state_current):
state_next = state_old * 2 + state_current + seq
return state_next
out, _ = theano.scan(inner_fct, sequences=seq,
outputs_info={'initial':x, 'taps':[-2,-1]})
g_out = theano.grad(out.sum(), [seq, x])
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
def test_grad_two_scans(self):
# data input & output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论