提交 83c7e294 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

I changed the perform of scan to use directly the linker to execute the

inner function, without creating a special type of function. This reduced understanding complexity, but also reduced a potential source of bugs ( when unpickling, or having alias outputs). I also added a new test for the bug Arnaud discovered, which this fix also adresses in a more broad way.
上级 b5fe3cd1
...@@ -846,15 +846,7 @@ def scan( fn ...@@ -846,15 +846,7 @@ def scan( fn
info['inplace'] = False info['inplace'] = False
info['gpu'] = False info['gpu'] = False
revised_outs = [] local_op = scan_op.Scan( inner_inputs, new_outs, info )
for o in new_outs:
if (o in inner_inputs or
isinstance(o, tensor.Constant)):
revised_outs.append( scan_utils.cloneOp(o))
else:
revised_outs.append(o)
local_op = scan_op.Scan( inner_inputs, revised_outs, info )
## ##
### Step 8. Compute the outputs using the scan op ### Step 8. Compute the outputs using the scan op
......
...@@ -2007,7 +2007,34 @@ class T_Scan(unittest.TestCase): ...@@ -2007,7 +2007,34 @@ class T_Scan(unittest.TestCase):
assert scan1.owner.op == scan2.owner.op assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op) assert hash(scan1.owner.op) == hash(scan2.owner.op)
def test_same(self):
# This test is checking a bug discovered by Arnaud and it is based
# on his code
x = theano.tensor.fmatrix('x')
mem_val = numpy.zeros((2,), dtype='float32')
memory = theano.shared(mem_val.copy())
W = theano.shared(numpy.random.random((5, 2)).astype('float32'))
def f(inp, mem):
i = theano.tensor.join(0, inp, mem)
d = theano.tensor.dot(i, W)
return d, d
outs, updts = theano.scan(f, sequences=[x],
non_sequences=[],
outputs_info=[None, memory])
f = theano.function([x], outs[0])
f2 = theano.function([x], outs[1])
x_val = numpy.random.random((4, 3)).astype('float32')
f_vals = f(x_val)
memory.set_value(mem_val.copy())
f2_vals = f2(x_val)
assert numpy.allclose(f_vals, f2_vals)
if __name__ == '__main__': if __name__ == '__main__':
#''' #'''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论