提交 a5b4de5a · 作者：abergeron

Merge pull request #3292 from carriepl/scan_0_steps

Prevent ScanSaveMem from creating 0-steps scans.
...@@ -1258,6 +1258,14 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1258,6 +1258,14 @@ class ScanSaveMem(gof.Optimizer):
real_steps = None real_steps = None
nw_steps = select_min(select_max(sym_steps, real_steps), nw_steps = select_min(select_max(sym_steps, real_steps),
node.inputs[0]) node.inputs[0])
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# subtensor nodes that end up taking no elements) because Scan with
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
nw_steps = select_max(nw_steps, 1)
else: else:
nw_steps = node.inputs[0] nw_steps = node.inputs[0]
global_nsteps = None global_nsteps = None
......
...@@ -3706,6 +3706,49 @@ class T_Scan(unittest.TestCase): ...@@ -3706,6 +3706,49 @@ class T_Scan(unittest.TestCase):
n_steps=5) n_steps=5)
rval = theano.function([], y2.sum())() rval = theano.function([], y2.sum())()
def test_savemem_opt_0_step(self):
    # Regression check for the ScanSaveMem optimization: it must never
    # shrink the number of steps of a Scan all the way down to 0, since
    # Scan nodes with 0 iterations are unsupported and would crash at
    # function-execution time. Compiling and running this graph used to
    # expose exactly that situation.
    def step_inner(x_t_t, h_tm1, w):
        # One step of the inner recurrence: h_t = h_{t-1} . w + x_t
        return tensor.dot(h_tm1, w) + x_t_t

    def step_outer(x_t, w):
        # Run the inner scan over all but the first row of x_t, seeded
        # with that first row as the initial state.
        h, _ = theano.scan(step_inner,
                           sequences=[x_t[1:]],
                           outputs_info=[x_t[0]],
                           non_sequences=[w],
                           strict=True,
                           name="the_inner_scan")
        return h

    def build_outputs(x, w):
        # Nest the scans, then differentiate w.r.t. the shared weight
        # matrix; the gradient graph is where savemem gets its chance
        # to produce a 0-step scan.
        features, _ = theano.scan(step_outer,
                                  sequences=[x],
                                  non_sequences=[w],
                                  strict=True,
                                  name="the_outer_scan")
        return tensor.grad(features.sum(), w)

    # Compile the theano function
    x = tensor.tensor3('x')
    w = tensor.matrix('w')
    f = theano.function(inputs=[x, w], outputs=build_outputs(x, w))

    # Execute on concrete data and verify the gradient value.
    x_val = numpy.random.random((2, 2, 3)).astype(theano.config.floatX)
    w_val = numpy.random.random((3, 3)).astype(theano.config.floatX)
    expected = numpy.tile(x_val[:, 0].sum(0), (3, 1)).transpose()
    utt.assert_allclose(f(x_val, w_val), expected)
def test_grad_multiple_taps_state(self): def test_grad_multiple_taps_state(self):
# The test is based on the code provided by Timothy Lillicrap # The test is based on the code provided by Timothy Lillicrap
......
Markdown 格式
0%
您即将添加 0 人到此讨论，请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论