提交 d8d11326 authored 作者: lamblin's avatar lamblin

Merge pull request #1030 from pascanur/fix_borrow_flag_scan

Borrow=True is dangerous if one output destroys another
......@@ -530,7 +530,7 @@ class Scan(PureOp):
self.n_sit_sot +
self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in
wrapped_outputs = [Out(x, borrow=False) for x in
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
profile = None
......
......@@ -3206,6 +3206,27 @@ class T_Scan(unittest.TestCase):
f = theano.function([seq], results[1])
assert numpy.all(exp_out == f(inp))
def test_borrow_bug_jeremiah(self):
# This test fails if scan uses wrongly the borrow flag
inp = numpy.arange(10).reshape(-1,1).astype(theano.config.floatX)
exp_out = numpy.zeros((10,1)).astype(theano.config.floatX)
exp_out[4:] = inp[:-4]
def onestep(x, x_tm4):
return x, x_tm4
seq = tensor.matrix()
initial_value = theano.shared(numpy.zeros((4,1),
dtype=theano.config.floatX))
outputs_info = [{'initial' : initial_value, 'taps' : [-4]}, None]
results, _ = theano.scan(fn=onestep,
sequences=seq,
outputs_info=outputs_info)
sharedvar = theano.shared(numpy.zeros((1,1)))
updates = {sharedvar : results[0][-1:]}
f = theano.function([seq], results[1], updates=updates)
assert numpy.all(exp_out == f(inp))
def test_speed():
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论