提交 9a3d74ad authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new test checking the borrow flag

上级 f24f31bf
...@@ -3206,6 +3206,27 @@ class T_Scan(unittest.TestCase): ...@@ -3206,6 +3206,27 @@ class T_Scan(unittest.TestCase):
f = theano.function([seq], results[1]) f = theano.function([seq], results[1])
assert numpy.all(exp_out == f(inp)) assert numpy.all(exp_out == f(inp))
def test_borrow_bug_jeremiah(self):
# This test fails if scan uses wrongly the borrow flag
inp = numpy.arange(10).reshape(-1,1).astype(theano.config.floatX)
exp_out = numpy.zeros((10,1)).astype(theano.config.floatX)
exp_out[4:] = inp[:-4]
def onestep(x, x_tm4):
return x, x_tm4
seq = tensor.matrix()
initial_value = theano.shared(numpy.zeros((4,1),
dtype=theano.config.floatX))
outputs_info = [{'initial' : initial_value, 'taps' : [-4]}, None]
results, _ = theano.scan(fn=onestep,
sequences=seq,
outputs_info=outputs_info)
sharedvar = theano.shared(numpy.zeros((1,1)))
updates = {sharedvar : results[0][-1:]}
f = theano.function([seq], results[1], updates=updates)
assert numpy.all(exp_out == f(inp))
def test_speed(): def test_speed():
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论