提交 b2d3d192 authored 作者: abergeron's avatar abergeron

Merge pull request #2710 from carriepl/scan_crash_infer_shape

Fix crash in scan.infer_shape() and add test case.
...@@ -1265,8 +1265,13 @@ class Scan(PureOp): ...@@ -1265,8 +1265,13 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if self.as_while: if self.as_while:
scan_outs = [(Shape_i(0)(o),) + x[1:] scan_outs_init = scan_outs
for o, x in izip(node.outputs, scan_outs)] scan_outs = []
for o, x in izip(node.outputs, scan_outs_init):
if x is None:
scan_outs.append(None)
else:
scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs return scan_outs
def get_input_pos(self, output_index): def get_input_pos(self, output_index):
......
...@@ -2519,6 +2519,29 @@ class T_Scan(unittest.TestCase): ...@@ -2519,6 +2519,29 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(tx4, v_u[-1] + 4.) utt.assert_allclose(tx4, v_u[-1] + 4.)
utt.assert_allclose(tx5, v_u[-1] + 5.) utt.assert_allclose(tx5, v_u[-1] + 5.)
def test_infer_shape(self):
# Test for a crash in scan.infer_shape when using both
# an until condition and random sampling in the inner function.
x = tensor.scalar()
srng = theano.tensor.shared_randomstreams.RandomStreams(0)
def inner_fct(previous_val):
new_val = previous_val + srng.uniform()
condition = theano.scan_module.until(previous_val > 5)
return new_val, condition
out, updates = theano.scan(inner_fct,
outputs_info=x,
n_steps=10)
g_out = tensor.grad(out.sum(), x)
fct = theano.function([x], [out, g_out])
for i in xrange(-5, 5):
output, g_output = fct(i)
assert len(output) == g_output
# The following test will fail in DebugMode if there are # The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape # some problems in Scan.infer_shape
def test_remove_stuff(self): def test_remove_stuff(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论