提交 bbebd2d3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

test for bug fix to scan

上级 bf8f180c
...@@ -285,6 +285,35 @@ class T_Scan(unittest.TestCase): ...@@ -285,6 +285,35 @@ class T_Scan(unittest.TestCase):
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
assert numpy.allclose(numpy_values, theano_values) assert numpy.allclose(numpy_values, theano_values)
def test_subtensor_multiple_slices(self):
# This addresses a bug reported by Matthias Zoehrer
# the bug happens when you index on the second dimension,
# case in which the scan save mem optimization fails
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.vector('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = theano.scan(f_pow2,
[],
state,
[],
n_steps=n_steps,
truncate_gradient=-1,
go_backwards=False)
nw_shape = tensor.ivector('nw_shape')
my_f = theano.function([state, n_steps,nw_shape],
[tensor.reshape(output,nw_shape,ndim=3)[:-2],
output[:-4]],
updates=updates,
allow_input_downcast=True)
nodes = [x for x in my_f.maker.env.toposort()
if isinstance(x.owner.op, theano.scan_module.scan_op.Scan)]
# This assertation fails if savemem optimization failed on scan
assert nodes[0]._scan_savemem_visited
# simple rnn, one input, one state, weights for each; input/state # simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars # are vectors, weights are scalars
def test_one_sequence_one_output_weights(self): def test_one_sequence_one_output_weights(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论