提交 2b0880dd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3102 from carriepl/scan_util_expand

Match broadcastable pattern of original variable
...@@ -401,6 +401,12 @@ def expand(tensor_var, size): ...@@ -401,6 +401,12 @@ def expand(tensor_var, size):
zeros_shape = [size + shapes[0]] + shapes[1:] zeros_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.zeros(zeros_shape, empty = tensor.zeros(zeros_shape,
dtype=tensor_var.dtype) dtype=tensor_var.dtype)
# Make sure to reuse the broadcast pattern of the original tensor for
# every dimension but the first one.
broadcastable = (False,) + tensor_var.broadcastable[1:]
empty = tensor.patternbroadcast(empty, broadcastable)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
......
...@@ -2431,6 +2431,43 @@ class T_Scan(unittest.TestCase): ...@@ -2431,6 +2431,43 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(output1, expected_output1) utt.assert_allclose(output1, expected_output1)
utt.assert_allclose(output2, expected_output2) utt.assert_allclose(output2, expected_output2)
def test_use_scan_direct_output2(self):
    """Regression test for directly using raw recurrent scan outputs.

    Looks for a crash that happened when directly using the recurrent
    output of a scan node associated with a state that has broadcastable
    dimensions (here, column matrices created with ``tensor.dcol``).
    """
    x = tensor.dcol()
    seq = tensor.dcol()
    initial_states = [x, tensor.zeros_like(x)]

    def step(a, b, c):
        # a: current slice of the sequence; b, c: previous values of the
        # two recurrent states.
        return a + b, a + c

    (out1, out2), updates = theano.scan(step,
                                        sequences=seq,
                                        outputs_info=initial_states)

    # scan hands back its outputs wrapped in a Subtensor; grab the
    # variables feeding those Subtensor nodes and compile a function
    # that returns them directly.
    direct_outputs = []
    for scan_out in (out1, out2):
        assert isinstance(scan_out.owner.op, tensor.subtensor.Subtensor)
        direct_outputs.append(scan_out.owner.inputs[0])
    fct = theano.function([x, seq], direct_outputs)

    # Run the compiled function on small (4, 1) column inputs.
    x_val = numpy.arange(0, 4)[:, None]
    seq_val = numpy.arange(4, 8)[:, None]
    out1_val, out2_val = fct(x_val, seq_val)

    # The second state starts at zero and accumulates the sequence, so
    # step i + 1 holds the running sum of seq_val[0..i], broadcast over
    # the (4, 1) state. The first state is the same series offset by x_val.
    expected_out2 = numpy.zeros((5, 4, 1))
    expected_out2[1:] = numpy.cumsum(seq_val[:, 0])[:, None, None]
    expected_out1 = expected_out2 + x_val

    utt.assert_allclose(out1_val, expected_out1)
    utt.assert_allclose(out2_val, expected_out2)
def test_infer_shape(self): def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both # Test for a crash in scan.infer_shape when using both
# an until condition and random sampling in the inner function. # an until condition and random sampling in the inner function.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论