Commit 7b8d56d2 authored by abergeron

Merge pull request #2642 from carriepl/scan_grad_crash

[CRASH] Fix crash in scan.grad when same variable used for multiple outputs
......@@ -1597,7 +1597,16 @@ class Scan(PureOp):
if idx >= self.n_mit_mot_outs:
Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder)
if Xt not in self.inner_nitsot_outs(self_outputs):
# Different processing based on whether Xt is a nitsot output
# or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
idx_nitsot_start = (self.info['n_mit_mot'] +
self.info['n_mit_sot'] +
self.info['n_sit_sot'])
idx_nitsot_end = idx_nitsot_start + self.info['n_nit_sot']
if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
# upcast on all of their dtypes to get the dtype for this
......
......@@ -3279,6 +3279,60 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 0
def test_grad_duplicate_outputs(self):
    # Regression test: theano.grad must work on a scan whose inner
    # function returns the exact same variable for more than one of
    # its outputs (previously this crashed in Scan.grad).
    def step(seq_t, recur_tm3, free_var):
        # Argument order follows scan's convention: sequence element,
        # then the taps=[-3] recurrent state, then the non-sequence.
        s = seq_t + recur_tm3 + free_var
        # Same variable returned twice on purpose.
        return s, s

    # Symbolic inputs for the scan.
    sequence = tensor.matrix()
    initial_state = tensor.matrix()
    free_param = tensor.vector()

    # First output is nitsot (no feedback), second is sitsot-like with
    # a -3 tap fed from initial_state.
    results, _ = theano.scan(fn=step,
                             sequences=sequence,
                             outputs_info=[None,
                                           dict(initial=initial_state,
                                                taps=[-3])],
                             non_sequences=free_param)

    # Gradients of each of the two (identical) outputs w.r.t. every
    # input; both calls must succeed.
    wrt = [sequence, initial_state, free_param]
    grads = (theano.grad(results[0].sum(), wrt) +
             theano.grad(results[1].sum(), wrt))

    fct = theano.function([sequence, initial_state, free_param], grads)

    # Evaluate on random data and compare against hand-derived values.
    outputs = fct(numpy.random.random((10, 3)),
                  numpy.random.random((3, 3)),
                  numpy.random.random(3))

    expected_g_seq = numpy.array([[4, 4, 4]] +
                                 [[3, 3, 3]] * 3 +
                                 [[2, 2, 2]] * 3 +
                                 [[1, 1, 1]] * 3)
    expected = [expected_g_seq,
                expected_g_seq[:3],
                numpy.array([22, 22, 22])] * 2
    for actual, reference in zip(outputs, expected):
        utt.assert_allclose(actual, reference)
def test_grad_multiple_seqs_different_nsteps(self):
# Example provided Michael Forbes
# This test assures that we clip the sequences to n_steps before
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论