提交 4c5afa2d authored 作者: --global's avatar --global

Adapt unit test to test both paths of fix for direct outputs

上级 c4708d8e
...@@ -2526,27 +2526,35 @@ class T_Scan(unittest.TestCase): ...@@ -2526,27 +2526,35 @@ class T_Scan(unittest.TestCase):
x = tensor.scalar() x = tensor.scalar()
seq = tensor.vector() seq = tensor.vector()
out, updates = theano.scan(lambda a, b : a + b, sequences=seq, outputs_info=[x, tensor.zeros_like(x)]
outputs_info=x) (out1, out2), updates = theano.scan(lambda a, b, c : (a + b, b + c),
sequences=seq,
outputs_info=outputs_info)
# Obtain a reference to the scan output before the subtensor and # Obtain a reference to the scan outputs before the subtensor and
# compile a function with it as output # compile a function with them as outputs
assert isinstance(out.owner.op, tensor.subtensor.Subtensor) assert isinstance(out1.owner.op, tensor.subtensor.Subtensor)
out = out.owner.inputs[0] assert isinstance(out2.owner.op, tensor.subtensor.Subtensor)
fct = theano.function([x, seq], out[:-1]) out1_direct = out1.owner.inputs[0]
out2_direct = out2.owner.inputs[0]
fct = theano.function([x, seq], [out1_direct[:-1], out2_direct[:-1]])
# Test the function to ensure valid outputs # Test the function to ensure valid outputs
floatX = theano.config.floatX floatX = theano.config.floatX
init_value = 5.0 init_value = 5.0
seq_value = numpy.arange(4, dtype=floatX) seq_value = numpy.arange(4, dtype=floatX)
output = fct(init_value, seq_value) output1, output2 = fct(init_value, seq_value)
expected_output = [init_value] expected_output1 = [init_value]
expected_output2 = [0]
for i in seq_value[:-1]: for i in seq_value[:-1]:
expected_output.append(expected_output[-1] + i) expected_output2.append(expected_output1[-1] +
expected_output2[-1])
expected_output1.append(expected_output1[-1] + i)
utt.assert_allclose(output, expected_output) utt.assert_allclose(output1, expected_output1)
utt.assert_allclose(output2, expected_output2)
def test_infer_shape(self): def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both # Test for a crash in scan.infer_shape when using both
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论