Commit dc9d2a89 authored by Razvan Pascanu


new optimization for scan to be smarter about memory allocation .. plus a few notes in the documentation
parent e4bb7837
@@ -53,7 +53,10 @@
 Scan will return a tuple, containing our result (``result``) and a
 dictionary of updates ( empty in this case). Note that the result
 is not a matrix, but a 3D tensor containing the value of ``A**k`` for
 each step. We want the last value ( after k steps ) so we compile
-a function to return just that.
+a function to return just that. Note that there is an optimization that,
+at compile time, will detect that you are using just the last value of the
+result and ensure that scan does not store all the intermediate values
+it computes. So do not worry if A and k are large.
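For readers following the tutorial, here is a minimal sketch of the pattern this paragraph describes, modeled on the standard Theano power-of-A example. The exact set of keyword arguments accepted by ``scan`` in this revision may differ (``n_steps`` in particular is assumed here), and the variable names are illustrative:

    import theano
    import theano.tensor as T

    k = T.iscalar("k")
    A = T.matrix("A")

    # at each step, multiply the previous result elementwise by A
    result, updates = theano.scan(lambda prior, A: prior * A,
                                  outputs_info=T.ones_like(A),
                                  non_sequences=A,
                                  n_steps=k)

    # result is a 3D tensor with one slice per step; asking only for
    # result[-1] lets the optimization described above avoid storing
    # the intermediate slices
    final_k = theano.function([A, k], result[-1], updates=updates)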
Multiple outputs, several taps values - Recurrent Neural Network with Scan
--------------------------------------------------------------------------
@@ -208,5 +211,9 @@ Reference
 .. automodule:: theano.scan

+.. autofunction:: theano.map
+.. autofunction:: theano.reduce
+.. autofunction:: theano.foldl
+.. autofunction:: theano.foldr
 .. autofunction:: theano.scan
@@ -28,6 +28,7 @@ __docformat__ = 'restructedtext en'
 import theano
 from theano.tensor import opt, TensorType
 from theano import gof, Apply
+from theano.gof import Optimizer, toolbox
 from theano.compile import optdb
 import theano.tensor.shared_randomstreams as shared_random
 from theano.gof.python25 import all
@@ -122,14 +123,12 @@ def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False
     for i,out_info in enumerate(outs_info):
         if out_info:
             if not type(out_info) == dict:
-                outs_info[i] = dict(initial = out_info, taps = [-1], store_steps = 1)
+                outs_info[i] = dict(initial = out_info, return_steps = 1)
             else:
-                # we force to use only the last step
-                # and store only the alst step
-                outs_info[i]['taps'] = [-1]
-                outs_info[i]['store_steps'] = 1
+                # we tell scan to store only the last step
+                outs_info[i]['return_steps'] = 1
-    # NOTE : Maybe some errors can be detected here were we can give
-    # more meaningfull error messages than in scan RP
+    # NOTE : Maybe some errors can be detected here and
+    # we could give more meaningful error messages than in scan?
     return scan(fn, sequences = sequences, outputs_info = outs_info,
                 non_sequences = non_sequences, go_backwards = go_backwards,
                 truncate_gradient = 1, mode = mode)
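As a usage sketch mirroring the test further down (the argument order in ``fn`` is assumed to follow scan's convention of sequence elements before the previous output; names are illustrative):

    import numpy
    import theano
    import theano.tensor as T

    v = T.vector("v")
    s = T.scalar("s")

    # reduce keeps only the last step, so result is the final running sum
    result, updates = theano.reduce(lambda x, acc: acc + x,
                                    sequences=v, outputs_info=s)

    f = theano.function([v, s], result, updates=updates)
    print f(numpy.arange(5, dtype="float64"), 0.)   # the running sum: 10.0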
@@ -276,6 +275,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
     flag tells scan that the output should be computed in the memory spaced occupied
     by that input sequence. Note that scan will only do this if allowed by the
     rest of your computational graph.
+    * ``return_steps`` how many steps of the output to return. If not given or
+      0, scan will return all steps; otherwise it will return the last
+      ``return_steps`` steps. Note that if you set this to something other than
+      0, scan will always be smart about the amount of memory it allocates for that output.

     If the function applied recursively uses only the
     previous value of the output, the initial state should have
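A sketch of how the new flag would be passed through ``outputs_info`` (variable names hypothetical, and the remaining dict keys assumed to keep their defaults):

    # keep only the final state of this output; scan can then allocate
    # a buffer for a single step instead of one entry per iteration
    result, updates = theano.scan(lambda x, acc: acc + x,
                                  sequences=v,
                                  outputs_info=dict(initial=s,
                                                    return_steps=1))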
@@ -525,9 +528,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
     store_steps = [ 0 for i in xrange(n_outs)]
     for i in xrange(n_outs):
-        if outs_info[i].get('store_steps', None):
-            print 'here'
-            store_steps[i] = outs_info[i]['store_steps']
+        if outs_info[i].get('return_steps', None):
+            store_steps[i] = outs_info[i]['return_steps']

     # add shared variable that act as outputs
     #
@@ -632,7 +634,6 @@ class Scan(theano.Op):
         if k > n_seqs:
             raise ValueError(('Sequences past taps dictionary reffers to '
                               'an unexisting sequence %d')%k)
-
         #check outputs past taps
         for k,v in outs_taps.iteritems():
             if k > n_outs:
@@ -679,6 +680,7 @@ class Scan(theano.Op):
         self.inputs = inputs
         self.givens = givens
         self.outputs = outputs
+        self.mode = mode
         self.truncate_gradient = truncate_gradient
         self.go_backwards = go_backwards
         self.slice_to_seqs = slice_to_seqs
@@ -706,6 +708,7 @@ class Scan(theano.Op):
                (self.seqs_taps == other.seqs_taps) and \
                (self.outs_taps == other.outs_taps) and \
                (self.inplace_map == other.inplace_map) and \
+               (self.mode == other.mode) and \
                (self.n_seqs == other.n_seqs) and\
                (self.inplace == other.inplace) and\
                (self.go_backwards == other.go_backwards) and\
@@ -725,6 +728,7 @@ class Scan(theano.Op):
                hash(self.go_backwards) ^\
                hash(self.truncate_gradient) ^\
                hash(self.n_args) ^ \
+               hash(self.mode) ^\
                hash_listsDictsTuples(self.outputs) ^ \
                hash_listsDictsTuples(self.inputs) ^ \
                hash_listsDictsTuples(self.givens) ^ \
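The two hunks above make ``mode`` part of the Op's identity. Theano's merge optimization treats ops that compare equal as interchangeable, so two Scan instances differing only in ``mode`` must not compare equal, and Python requires that objects that compare equal also hash equal. A toy illustration of that contract (not Theano code):

    class ToyOp(object):
        def __init__(self, mode):
            self.mode = mode

        def __eq__(self, other):
            # mode participates in equality, so ops built with
            # different modes are never merged
            return type(self) == type(other) and self.mode == other.mode

        def __hash__(self):
            # objects that compare equal must hash equal too
            return hash(type(self)) ^ hash(self.mode)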
@@ -1048,13 +1052,96 @@ class Scan(theano.Op):
     '''

+class ScanSpaceOptimizer(Optimizer):
+    """ Graph optimizer that reduces scan memory consumption """
+    def __init__(self):
+        Optimizer.__init__(self)
+
+    def add_requirements(self, env):
+        env.extend(toolbox.ReplaceValidate())
+
+    def apply(self, env):
+        nodelist = list(env.toposort())
+        for node in nodelist:
+            op = node.op
+            # if it is a scan Op
+            if isinstance(op, Scan):
+                outputs = node.outputs
+                store_steps = [0 for x in outputs]
+                # check the outputs
+                for i, out in enumerate(node.outputs):
+                    if op.store_steps[i] == 0:
+                        # we do not have a range for this output yet
+                        req_steps = 0
+                        # look at all its clients
+                        for cl, _dx in out.clients:
+                            if type(cl) == str:
+                                # if the node is actually an output, then
+                                # we need to store the entire thing
+                                req_steps = 0
+                                break
+                            else:
+                                if not isinstance(cl.op,
+                                        theano.tensor.basic.Subtensor):
+                                    # if any of the clients is not a subtensor
+                                    # we also need to store the entire thing
+                                    req_steps = 0
+                                    break
+                                else:
+                                    # it is a subtensor; check whether the
+                                    # index into the first dimension is just -1
+                                    if cl.op.idx_list[0] == -1:
+                                        req_steps = 1
+                                    else:
+                                        # or a constant that evaluates to -1
+                                        try:
+                                            idx = opt.get_constant_value(
+                                                        cl.op.idx_list[0])
+                                            if idx == -1:
+                                                req_steps = 1
+                                            else:
+                                                req_steps = 0
+                                                break
+                                        except:
+                                            req_steps = 0
+                                            break
+                        store_steps[i] = req_steps
+                    else:
+                        store_steps[i] = op.store_steps[i]
+                if numpy.any(store_steps != op.store_steps):
+                    new_scan = Scan((op.inputs, op.outputs, op.givens,
+                            op.slice_to_seqs), op.n_seqs, op.n_outs,
+                            op.inplace_map, op.seqs_taps, op.outs_taps,
+                            op.truncate_gradient, op.go_backwards,
+                            store_steps, op.mode, op.inplace).make_node(*node.inputs)
+                    # we now need to replace the outputs of scan
+                    for i, out in enumerate(node.outputs):
+                        # if we are dealing with an output for which
+                        # we changed the number of stored steps we
+                        # also need to get rid of the subtensor
+                        if op.store_steps[i] == 0 and store_steps[i] == 1:
+                            # get the outputs of the subtensor variables
+                            outSubTens = [x[0].outputs[0] for x in out.clients]
+                            new_old = [(x, new_scan.outputs[i]) for x in outSubTens]
+                            env.replace_all_validate(new_old,
+                                    reason='scan_space_optimizer')
+                        else:
+                            env.replace_all_validate([(out,
+                                    new_scan.outputs[i])],
+                                    reason='scan_space_optimizer')
+
+optdb.register('scanOp_space_optimization', ScanSpaceOptimizer(), 74, 'fast_run')
+
 @gof.local_optimizer([None])
 def scan_make_inplace(node):
     op = node.op
     if isinstance(op, Scan) and (not op.inplace) and (op.inplace_map.keys() != []):
         return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs,
                     op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps,
-                    op.truncate_gradient, op.go_backwards, op.store_steps,
+                    op.truncate_gradient, op.go_backwards, op.store_steps, op.mode,
                     inplace=True ).make_node(*node.inputs).outputs
     return False
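The optimizer above fires only when every client of a scan output is a ``Subtensor`` whose first index is -1, either literally or as a constant. A sketch of the two situations, reusing the tutorial-style API assumed earlier (names illustrative):

    v = T.vector("v")
    acc, updates = theano.scan(lambda x, prior: prior + x,
                               sequences=v,
                               outputs_info=T.constant(0.))

    # acc[-1] is a Subtensor whose first index is the constant -1, so
    # ScanSpaceOptimizer can tell scan to store a single step
    last = theano.function([v], acc[-1], updates=updates)

    # returning the whole output (or any other slice) leaves
    # store_steps at 0 and scan keeps every intermediate value
    everything = theano.function([v], acc, updates=updates)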
...
@@ -606,6 +606,7 @@ class T_Scan(unittest.TestCase):
         f = theano.function([v,s], result, updates = updates)
         rng = numpy.random.RandomState(utt.fetch_seed())
         v_v = rng.uniform( size = (5,), low = -5., high = 5.)
+        print f(v_v,0.)
         assert ( numpy.sum(v_v) == f(v_v, 0.) )
...