提交 01e433a3 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add a test case for previous crash

上级 4566621b
...@@ -2519,6 +2519,35 @@ class T_Scan(unittest.TestCase): ...@@ -2519,6 +2519,35 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(tx4, v_u[-1] + 4.) utt.assert_allclose(tx4, v_u[-1] + 4.)
utt.assert_allclose(tx5, v_u[-1] + 5.) utt.assert_allclose(tx5, v_u[-1] + 5.)
def test_use_scan_direct_output(self):
# This test looks for a crash that happened when directly using the
# recurrent output of a scan node instead of taking the result
# returned by the scan() function
x = tensor.scalar()
seq = tensor.vector()
out, updates = theano.scan(lambda a, b : a + b, sequences=seq,
outputs_info=x)
# Obtain a reference to the scan output before the subtensor and
# compile a function with it as output
assert isinstance(out.owner.op, tensor.subtensor.Subtensor)
out = out.owner.inputs[0]
fct = theano.function([x, seq], out[:-1])
# Test the function to ensure valid outputs
floatX = theano.config.floatX
init_value = 5.0
seq_value = numpy.arange(4, dtype=floatX)
output = fct(init_value, seq_value)
expected_output = [init_value]
for i in seq_value[:-1]:
expected_output.append(expected_output[-1] + i)
utt.assert_allclose(output, expected_output)
def test_infer_shape(self): def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both # Test for a crash in scan.infer_shape when using both
# an until condition and random sampling in the inner function. # an until condition and random sampling in the inner function.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论