Commit 8d1910e3 authored by abergeron

Merge pull request #2717 from carriepl/scan_crash_infer_shape

[CRASH] Fix crash in save_mem when directly using outputs of the scan node
......@@ -1336,12 +1336,33 @@ class ScanSaveMem(gof.Optimizer):
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op.n_mit_mot:]):
if val == 0:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_inputs[offset + idx] = nw_input
elif idx < (op.n_mit_sot + op.n_sit_sot +
op.n_nit_sot):
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op.n_mit_mot + idx]
# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if ((nw_inputs[in_idx].owner and
isinstance(nw_inputs[in_idx].owner.op,
tensor.IncSubtensor) and
isinstance(
nw_inputs[in_idx].owner.op.idx_list[0],
slice))):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = scan_utils.expand(_nw_input,
nw_steps)
nw_inputs[in_idx] = nw_input
else:
nw_input = nw_inputs[in_idx][:(initl+nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps
......
......@@ -2581,6 +2581,53 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(tx4, v_u[-1] + 4.)
utt.assert_allclose(tx5, v_u[-1] + 5.)
def test_use_scan_direct_output(self):
    """Regression test: using the recurrent output of a scan node
    directly (bypassing the subtensor returned by scan()) used to
    crash the ScanSaveMem optimization.
    """
    # Obtain a compilation mode that will cause the test to fail if an
    # exception occurs in the optimization process.  Restore the flag
    # in a finally block so that an exception raised while building the
    # mode cannot leave the process-wide config stuck on "raise".
    on_opt_error = theano.config.on_opt_error
    theano.config.on_opt_error = "raise"
    try:
        mode = theano.compile.get_default_mode()
    finally:
        theano.config.on_opt_error = on_opt_error

    x = tensor.scalar()
    seq = tensor.vector()
    outputs_info = [x, tensor.zeros_like(x)]
    (out1, out2), updates = theano.scan(lambda a, b, c: (a + b, b + c),
                                        sequences=seq,
                                        outputs_info=outputs_info,
                                        mode=mode)

    # Obtain a reference to the scan outputs before the subtensor and
    # compile a function with them as outputs
    assert isinstance(out1.owner.op, tensor.subtensor.Subtensor)
    assert isinstance(out2.owner.op, tensor.subtensor.Subtensor)
    out1_direct = out1.owner.inputs[0]
    out2_direct = out2.owner.inputs[0]
    fct = theano.function([x, seq],
                          [out1_direct[:-1], out2_direct[:-1]],
                          mode=mode)

    # Test the function to ensure valid outputs
    floatX = theano.config.floatX
    init_value = 5.0
    seq_value = numpy.arange(4, dtype=floatX)
    output1, output2 = fct(init_value, seq_value)

    # Replay the recurrence in pure Python to build the expected
    # buffers.  Note: out2 is updated from the *previous* out1, so the
    # append to expected_output2 must happen before the one to
    # expected_output1 — do not reorder.
    expected_output1 = [init_value]
    expected_output2 = [0]
    for i in seq_value[:-1]:
        expected_output2.append(expected_output1[-1] +
                                expected_output2[-1])
        expected_output1.append(expected_output1[-1] + i)

    utt.assert_allclose(output1, expected_output1)
    utt.assert_allclose(output2, expected_output2)
def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both
# an until condition and random sampling in the inner function.
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment