提交 f221a38e authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Fix crash in scan.infer_shape() and add test case.

上级 2d733797
......@@ -1265,8 +1265,13 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
scan_outs = [(Shape_i(0)(o),) + x[1:]
for o, x in izip(node.outputs, scan_outs)]
scan_outs_init = scan_outs
scan_outs = []
for o, x in izip(node.outputs, scan_outs_init):
if x is None:
scan_outs.append(None)
else:
scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs
def get_input_pos(self, output_index):
......
......@@ -2519,6 +2519,25 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(tx4, v_u[-1] + 4.)
utt.assert_allclose(tx5, v_u[-1] + 5.)
def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both
# an until condition and random sampling in the inner function.
x = tensor.scalar()
srng = theano.tensor.shared_randomstreams.RandomStreams(0)
def inner_fct(previous_val):
new_val = previous_val + srng.uniform()
condition = theano.scan_module.until(previous_val > 50)
return new_val, condition
out, updates = theano.scan(inner_fct,
outputs_info=x,
n_steps=100)
g_out = tensor.grad(out.sum(), x)
fct = theano.function([x], out)
# The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape
def test_remove_stuff(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论