提交 27bce11d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed bug with inplace computation for gradient

上级 6d6ce74d
...@@ -18,8 +18,9 @@ Special cases: ...@@ -18,8 +18,9 @@ Special cases:
- A ``map()`` operation can be performed by applying a function that - A ``map()`` operation can be performed by applying a function that
ignores each previous output. ignores each previous output.
Often a for loop can be expressed as a ``scan()`` operation, and ``scan`` is the closest that theano comes to looping. The advantage of using ``scan`` over Often a for loop can be expressed as a ``scan()`` operation, and ``scan`` is
for loops is that it allows you to express the loop symbolically. The the closest that theano comes to looping. The advantage of using ``scan``
over for loops is that it allows you to express the loop symbolically. The
Scan Op should always be used by applying the ``scan`` function. Scan Op should always be used by applying the ``scan`` function.
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
...@@ -299,7 +300,7 @@ class Scan(theano.Op): ...@@ -299,7 +300,7 @@ class Scan(theano.Op):
self.destroy_map = {} self.destroy_map = {}
if inplace: if inplace:
for i in inplace_map.keys(): for i in inplace_map.keys():
self.destroy_map.update({i: [inplace_map[i]] } ) self.destroy_map.update({i: [inplace_map[i]+1] } )
self.seqs_taps = seqs_taps self.seqs_taps = seqs_taps
self.outs_taps = outs_taps self.outs_taps = outs_taps
...@@ -373,8 +374,6 @@ class Scan(theano.Op): ...@@ -373,8 +374,6 @@ class Scan(theano.Op):
rval = (self.inputs == other.inputs) and \ rval = (self.inputs == other.inputs) and \
(self.outputs == other.outputs) and \ (self.outputs == other.outputs) and \
(self.keep_outputs == other.keep_outputs) and \ (self.keep_outputs == other.keep_outputs) and \
(self.g_ins == other.g_ins) and \
(self.g_outs == other.g_outs) and \
(self.seqs_taps == other.seqs_taps) and \ (self.seqs_taps == other.seqs_taps) and \
(self.outs_taps == other.outs_taps) and \ (self.outs_taps == other.outs_taps) and \
(self.inplace_map == other.inplace_map) and \ (self.inplace_map == other.inplace_map) and \
...@@ -553,6 +552,9 @@ class Scan(theano.Op): ...@@ -553,6 +552,9 @@ class Scan(theano.Op):
def grad(self, args, g_outs): def grad(self, args, g_outs):
raise NotImplemented;
'''
if True: if True:
#((self.updates.keys() != []) or (self.inplace_map.keys() != [])\ #((self.updates.keys() != []) or (self.inplace_map.keys() != [])\
# or numpy.any(self.keep_outputs)): # or numpy.any(self.keep_outputs)):
...@@ -590,7 +592,7 @@ class Scan(theano.Op): ...@@ -590,7 +592,7 @@ class Scan(theano.Op):
self.truncate_gradient) self.truncate_gradient)
return g_scan(g_args) return g_scan(g_args)
'''
@gof.local_optimizer([None]) @gof.local_optimizer([None])
...@@ -598,20 +600,20 @@ def scan_make_inplace(node): ...@@ -598,20 +600,20 @@ def scan_make_inplace(node):
op = node.op op = node.op
if isinstance(op, Scan) and (not op.inplace) \ if isinstance(op, Scan) and (not op.inplace) \
and (op.inplace_map.keys() != []): and (op.inplace_map.keys() != []):
return Scan((op.inputs, op.outputs) , op.n_seqs, \ return Scan((op.inputs, op.outputs) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, \ op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps,
op.force_gradient, op.truncate_gradient, \ op.truncate_gradient, op.go_backwards, op.keep_outputs,
op.go_backwards, inplace=True \ inplace=True
).make_node(*node.inputs).outputs ).make_node(*node.inputs).outputs
return False return False
optdb.register('scan_make_inplace', opt.in2out(scan_make_inplace,\ optdb.register('scan_make_inplace', opt.in2out(scan_make_inplace,
ignore_newtrees=True), 75, 'fast_run', 'inplace') ignore_newtrees=True), 75, 'fast_run', 'inplace')
'''
class ScanGrad(theano.Op): class ScanGrad(theano.Op):
"""Gradient Op for Scan""" """Gradient Op for Scan"""
def __init__(self,(g_ins, g_outs) , n_seqs, n_outs, def __init__(self,(g_ins, g_outs) , n_seqs, n_outs,
...@@ -767,7 +769,7 @@ class ScanGrad(theano.Op): ...@@ -767,7 +769,7 @@ class ScanGrad(theano.Op):
for i,v in enumerate(g_ins + g_seeds+ g_non_seqs): for i,v in enumerate(g_ins + g_seeds+ g_non_seqs):
storage[i][0] = v storage[i][0] = v
'''
...@@ -221,13 +221,12 @@ class T_Scan(unittest.TestCase): ...@@ -221,13 +221,12 @@ class T_Scan(unittest.TestCase):
assert (compareArrays( out, f8(v_u, v_x0) ) ) assert (compareArrays( out, f8(v_u, v_x0) ) )
'''
# simple rnn ; compute inplace # simple rnn ; compute inplace
def test_7(self): def test_7(self):
u = theano.tensor.dvector() u = theano.tensor.dvector()
mu = theano.Param( u, mutable = True) mu = theano.Param( u, mutable = True)
x0 = theano.tensor.dvector() x0 = theano.tensor.dscalar()
W_in = theano.shared(.1) W_in = theano.shared(.1)
W = theano.shared(1.) W = theano.shared(1.)
...@@ -238,13 +237,12 @@ class T_Scan(unittest.TestCase): ...@@ -238,13 +237,12 @@ class T_Scan(unittest.TestCase):
f9 = theano.function([mu,x0], Y , #mode = 'FAST_RUN') f9 = theano.function([mu,x0], Y , #mode = 'FAST_RUN')
mode = 'DEBUG_MODE') mode = 'DEBUG_MODE')
v_u = numpy.array([1.,2.,3.]) v_u = numpy.array([1.,2.,3.])
v_x0 = numpy.array([1.]) v_x0 = numpy.array(1.)
out = f9(v_u, v_x0) out = f9(v_u, v_x0)
v_out = numpy.array([1.1,1.3,1.6]) v_out = numpy.array([1.1,1.3,1.6])
assert (compareArrays(out, v_out)) assert (compareArrays(out, v_out))
print v_u
assert (compareArrays(v_u, out)) assert (compareArrays(v_u, out))
''' '''
...@@ -252,14 +250,12 @@ class T_Scan(unittest.TestCase): ...@@ -252,14 +250,12 @@ class T_Scan(unittest.TestCase):
def test_10(self): def test_10(self):
pass pass
'''
TO TEST: TO TEST:
- test gradient (one output) - test gradient (one output)
- test gradient (multiple outputs) - test gradient (multiple outputs)
- test gradient (go_backwards) - test gradient (go_backwards)
- test gradient (multiple outputs / some uncomputable ) - test gradient (multiple outputs / some uncomputable )
- test gradient (truncate_gradient) - test gradient (truncate_gradient)
- test gradient (force_gradient)
- test_gradient (taps past/future) - test_gradient (taps past/future)
''' '''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论