提交 0e7a532e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2723 from carriepl/scan_connection_pattern

[CRASH] Fix crash in connection_pattern when there are mitmots
...@@ -1414,11 +1414,8 @@ class Scan(PureOp): ...@@ -1414,11 +1414,8 @@ class Scan(PureOp):
corresponding inner output(s) in a sequence. corresponding inner output(s) in a sequence.
""" """
s = 0 s = 0
if self.n_mit_mot > 0: e = 0
e = len(self.mitmot_out_taps()[0]) for p in xrange(oidx + 1):
else:
e = 1
for p in xrange(oidx):
s = e s = e
if p < self.n_mit_mot: if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p]) e += len(self.mitmot_out_taps()[p])
......
...@@ -838,6 +838,26 @@ class T_Scan(unittest.TestCase): ...@@ -838,6 +838,26 @@ class T_Scan(unittest.TestCase):
n_steps=2) n_steps=2)
tensor.grad(a[-1], a0) tensor.grad(a[-1], a0)
def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
# multiple output taps) output
x = tensor.matrix()
seq = tensor.vector()
def inner_fct(seq, state_old, state_current):
state_next = state_old * 2 + state_current + seq
return state_next
out, _ = theano.scan(inner_fct, sequences=seq,
outputs_info={'initial':x, 'taps':[-2,-1]})
g_out = theano.grad(out.sum(), [seq, x])
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
def test_grad_two_scans(self): def test_grad_two_scans(self):
# data input & output # data input & output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论