提交 39b4ec6e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2646 from carriepl/scan_grad_crash

[CRASH] Fix scan_op.connection_pattern and add test
...@@ -1349,33 +1349,33 @@ class Scan(PureOp): ...@@ -1349,33 +1349,33 @@ class Scan(PureOp):
e += 1 e += 1
return self.outputs[s:e] return self.outputs[s:e]
def _get_inner_inps(iidx): def _get_inner_inps(outer_iidx):
if node.inputs[iidx + 1] in self.outer_nitsot(node): """Given the index of an outer input, return the corresponding
return None inner input(s) as a sequence.
if node.inputs[iidx + 1] in self.outer_non_seqs(node): """
loc_idx = self.outer_non_seqs(node).index(
node.inputs[iidx + 1])
return [self.inner_non_seqs(self.inputs)[loc_idx]]
s = 0 outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
if self.n_seqs > 0:
e = 1 # For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
# The inner inputs can be selected this way because the indices in
# inner_iidxs are consecutive and in ascending order
if len(inner_iidxs) > 0:
inner_inputs = self.inputs[inner_iidxs[0]:inner_iidxs[-1]+1]
else: else:
e = len(self.tap_array[0]) inner_inputs = []
for p in xrange(iidx):
s = e return inner_inputs
if p < self.n_seqs:
e += 1
elif p - self.n_seqs < len(self.tap_array):
e += len(self.tap_array[p - self.n_seqs])
else:
e += 1
return self.inputs[s:e]
for oidx, out in enumerate(node.outputs): for oidx, out in enumerate(node.outputs):
for iidx, inp in enumerate(node.inputs[1:]): for iidx, inp in enumerate(node.inputs[1:]):
ols = _get_inner_outs(oidx) ols = _get_inner_outs(oidx)
ils = _get_inner_inps(iidx) ils = _get_inner_inps(iidx + 1)
if ils is None: if ils is None:
# The gradient should be disconnected # The gradient should be disconnected
...@@ -1439,7 +1439,7 @@ class Scan(PureOp): ...@@ -1439,7 +1439,7 @@ class Scan(PureOp):
return connection_pattern return connection_pattern
def get_outer_iidx_from_inner_iidx_seq(self): def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value of at the i-th position is the """ Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th inner input index of the outer input corresponding to the i-th inner input
""" """
......
...@@ -3333,6 +3333,35 @@ class T_Scan(unittest.TestCase): ...@@ -3333,6 +3333,35 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(outputs[4], expected_g_out_init) utt.assert_allclose(outputs[4], expected_g_out_init)
utt.assert_allclose(outputs[5], expected_g_non_seq) utt.assert_allclose(outputs[5], expected_g_non_seq)
def test_grad_duplicate_outputs_connection_pattern(self):
# This test checks for a crash in scan.connection_pattern when taking
# the grad of a scan with certain combinations of outputs.
def inner_fct(inp1, inp2, inp3, inp4, inp5, inp6):
total = inp1 + inp2 + inp3 + inp4 + inp5 + inp6
return total, total, total, total, total, total
# Assemble the scan
out_init = [tensor.vector(), tensor.vector(),
tensor.matrix(), tensor.matrix()]
outputs_info = ([None, None, out_init[0], out_init[1],
dict(initial=out_init[2], taps=[-2, -1]),
dict(initial=out_init[3], taps=[-2, -1])])
scan_outputs, _ = theano.scan(fn=inner_fct, outputs_info=outputs_info,
n_steps=10)
g_output0 = theano.grad(scan_outputs[0].sum(), out_init[1])
# Validate the connnection pattern is as it should be
node = scan_outputs[0].owner
connection_pattern = node.op.connection_pattern(node)
expected_connection_pattern = [[(j in [1,2,3,4]) for i in range(6)]
for j in range(7)]
assert connection_pattern == expected_connection_pattern
def test_grad_multiple_seqs_different_nsteps(self): def test_grad_multiple_seqs_different_nsteps(self):
# Example provided Michael Forbes # Example provided Michael Forbes
# This test assures that we clip the sequences to n_steps before # This test assures that we clip the sequences to n_steps before
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论