提交 df19095d authored 作者: Frederic Bastien's avatar Frederic Bastien

made the scan op more friendly to the profile mode.

Previously, the time spent inside the scan op was recorded in the same profile twice: once for the scan op itself AND once for the individual inner ops. Now multiple profiler instances are created to keep them separate.
上级 684f31f5
...@@ -74,8 +74,8 @@ def hash_listsDictsTuples(x): ...@@ -74,8 +74,8 @@ def hash_listsDictsTuples(x):
def map(fn, sequences, non_sequences = [], def map(fn, sequences, non_sequences = [],
truncate_gradient = -1, go_backwards = False, truncate_gradient = -1, go_backwards = False,
mode = 'FAST_RUN'): mode = None, name = None):
''' Similar behaviour as python map """ Similar behaviour as python map
:param fn: the function to be applied over the elements in :param fn: the function to be applied over the elements in
sequences ( see scan `fn` for more info) sequences ( see scan `fn` for more info)
...@@ -93,14 +93,16 @@ def map(fn, sequences, non_sequences = [], ...@@ -93,14 +93,16 @@ def map(fn, sequences, non_sequences = [],
:param mode: see scan :param mode: see scan
''' :param name: see scan
"""
return scan(fn, sequences= sequences, outputs_info = [],non_sequences= non_sequences, return scan(fn, sequences= sequences, outputs_info = [],non_sequences= non_sequences,
truncate_gradient= truncate_gradient, truncate_gradient= truncate_gradient,
go_backwards= go_backwards, mode = mode) go_backwards= go_backwards, mode = mode, name = name)
def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False, mode = 'FAST_RUN'): def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False,
''' Similar behaviour as python reduce mode = None, name = None):
""" Similar behaviour as python reduce
:param fn: the function to be applied over the elements in :param fn: the function to be applied over the elements in
sequences ( see scan `fn` for more info) sequences ( see scan `fn` for more info)
...@@ -117,7 +119,8 @@ def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False ...@@ -117,7 +119,8 @@ def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False
see scan for more info see scan for more info
:param mode: see scan :param mode: see scan
''' :param name: see scan
"""
# Specify that you only want the last value of the output # Specify that you only want the last value of the output
if type(outputs_info) not in (list,tuple): if type(outputs_info) not in (list,tuple):
outs_info = [outputs_info] outs_info = [outputs_info]
...@@ -135,10 +138,10 @@ def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False ...@@ -135,10 +138,10 @@ def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False
# we could give more meaningfull error messages then in scan ? # we could give more meaningfull error messages then in scan ?
return scan(fn, sequences = sequences, outputs_info = outs_info, return scan(fn, sequences = sequences, outputs_info = outs_info,
non_sequences = non_sequences, go_backwards = go_backwards, non_sequences = non_sequences, go_backwards = go_backwards,
truncate_gradient = 1, mode = mode) truncate_gradient = 1, mode = mode, name = name)
def foldl(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'): def foldl(fn, sequences, outputs_info, non_sequences = [], mode = None, name = None):
''' Similar behaviour as haskell foldl """ Similar behaviour as haskell foldl
:param fn: the function to be applied over the elements in :param fn: the function to be applied over the elements in
sequences ( see scan `fn` for more info) sequences ( see scan `fn` for more info)
...@@ -153,12 +156,13 @@ def foldl(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'): ...@@ -153,12 +156,13 @@ def foldl(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'):
foldl shouldn't iterate (see scan for more info) foldl shouldn't iterate (see scan for more info)
:param mode: see scan :param mode: see scan
''' :param name: see scan
"""
return reduce(fn = fn, sequences = sequences, outputs_info = outputs_info, return reduce(fn = fn, sequences = sequences, outputs_info = outputs_info,
non_sequences= non_sequences, go_backwards = False, mode = mode) non_sequences= non_sequences, go_backwards = False, mode = mode, name = name)
def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'): def foldr(fn, sequences, outputs_info, non_sequences = [], mode = None):
''' Similar behaviour as haskell foldr """ Similar behaviour as haskell foldr
:param fn: the function to be applied over the elements in :param fn: the function to be applied over the elements in
sequences ( see scan `fn` for more info) sequences ( see scan `fn` for more info)
...@@ -175,9 +179,10 @@ def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'): ...@@ -175,9 +179,10 @@ def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'):
:param truncate_gradient: see scan for more info :param truncate_gradient: see scan for more info
:param mode: see scan :param mode: see scan
''' :param name: see scan
"""
return reduce(fn = fn,sequences = sequences, outputs_info = outputs_info, return reduce(fn = fn,sequences = sequences, outputs_info = outputs_info,
non_sequences = non_sequences, go_backwards = True, mode = mode) non_sequences = non_sequences, go_backwards = True, mode = mode, name = name)
# CONSIDER ALTERNATE CALLING CONVENTIONS: # CONSIDER ALTERNATE CALLING CONVENTIONS:
# simple: # simple:
...@@ -210,8 +215,8 @@ def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'): ...@@ -210,8 +215,8 @@ def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'):
def scan(fn, sequences=[], outputs_info=[], non_sequences=[], def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
n_steps = None, truncate_gradient = -1, go_backwards = False, n_steps = None, truncate_gradient = -1, go_backwards = False,
mode = None): mode = None, name = None):
'''Function that constructs and applies a Scan op """Function that constructs and applies a Scan op
:param fn: :param fn:
Function that describes the operations involved in one step of scan Function that describes the operations involved in one step of scan
...@@ -340,13 +345,26 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -340,13 +345,26 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
:param go_backwards: :param go_backwards:
Flag indicating if you should go backwards through the sequences Flag indicating if you should go backwards through the sequences
:param name:
The name of the theano fct compiled by the Scan op. It will show in the
profiler output.
:param mode:
The mode used when compiling the theano fct in the Scan op.
If None, the config mode is used.
If None and the config mode is a profile mode, a new profiler instance is created
so that the timing is computed correctly.
Otherwise the time of the Scan op would be reported for the Scan op and the
time spent inside the Scan op fct would also be reported for its inner ops. The new profiler instance
will be printed when python exits.
:rtype: tuple :rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a :return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the Theano variable or a list of Theano variables representing the
outputs of scan. ``updates`` is a dictionary specifying the outputs of scan. ``updates`` is a dictionary specifying the
updates rules for all shared variables used in the scan updates rules for all shared variables used in the scan
operation; this dictionary should be pass to ``theano.function`` operation; this dictionary should be pass to ``theano.function``
''' """
# check if inputs are just single variables instead of lists # check if inputs are just single variables instead of lists
if not (type(sequences) in (list, tuple)): if not (type(sequences) in (list, tuple)):
...@@ -608,7 +626,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -608,7 +626,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# Create the Scan op object # Create the Scan op object
local_op = Scan( (inner_fn_inputs,inner_fn_out_states, givens, slice_to_seqs ), n_seqs, local_op = Scan( (inner_fn_inputs,inner_fn_out_states, givens, slice_to_seqs ), n_seqs,
n_extended_outs, inplace_map, sequences_taps, outputs_taps, truncate_gradient, n_extended_outs, inplace_map, sequences_taps, outputs_taps, truncate_gradient,
go_backwards, store_steps, mode, n_fixed_steps = n_fixed_steps) go_backwards, store_steps, mode, n_fixed_steps = n_fixed_steps, name = name)
# Call the object on the input sequences, initial values for outs, # Call the object on the input sequences, initial values for outs,
# and non sequences # and non sequences
...@@ -653,7 +671,8 @@ class Scan(Op): ...@@ -653,7 +671,8 @@ class Scan(Op):
inplace_map={}, seqs_taps={}, outs_taps={}, inplace_map={}, seqs_taps={}, outs_taps={},
truncate_gradient = -1, truncate_gradient = -1,
go_backwards = False, store_steps = {}, go_backwards = False, store_steps = {},
mode = 'FAST_RUN', n_fixed_steps = None, inplace=False): mode = None, n_fixed_steps = None, inplace=False,
name = None):
''' '''
:param (inputs,outputs, givens,slice_to_seqs): :param (inputs,outputs, givens,slice_to_seqs):
inputs and outputs Theano variables that describe the function that is inputs and outputs Theano variables that describe the function that is
...@@ -679,6 +698,8 @@ class Scan(Op): ...@@ -679,6 +698,8 @@ class Scan(Op):
received a number or None otherwise. The value is used to optimize received a number or None otherwise. The value is used to optimize
the graph, since a scan that has n_steps fixed to 1 or 0 is not the graph, since a scan that has n_steps fixed to 1 or 0 is not
really needed in the graph. (? could we use tag hints ?) really needed in the graph. (? could we use tag hints ?)
:param name: see scan fct
:param mode: see scan fct
''' '''
#check sequences past taps #check sequences past taps
for k,v in seqs_taps.iteritems(): for k,v in seqs_taps.iteritems():
...@@ -741,7 +762,21 @@ class Scan(Op): ...@@ -741,7 +762,21 @@ class Scan(Op):
self.go_backwards = go_backwards self.go_backwards = go_backwards
self.slice_to_seqs = slice_to_seqs self.slice_to_seqs = slice_to_seqs
self.fn = function(inputs,outputs, mode = mode, givens = givens) mode_instance = compile.mode.get_mode(mode)
#If we use the default mode and it is a ProfileMode, we must make a copy;
#otherwise the same time would be counted several times in the profile:
#1) The scan op, whose time includes all the time spent in the inner nodes.
#2) The inner ops of the scan, with their real time.
if mode is None and isinstance(mode_instance, compile.profilemode.ProfileMode):
mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(mode_instance)
self.mode_instance = mode_instance
if name is None: name = 'scan_fn'
self.fn = function(inputs,outputs, mode = mode_instance, givens = givens,
name = name)
assert not numpy.any([isinstance(x.variable,SharedVariable) for x in assert not numpy.any([isinstance(x.variable,SharedVariable) for x in
self.fn.maker.inputs]) self.fn.maker.inputs])
...@@ -1337,7 +1372,7 @@ class ScanSpaceOptimizer(Optimizer): ...@@ -1337,7 +1372,7 @@ class ScanSpaceOptimizer(Optimizer):
op.inplace_map, op.seqs_taps, op.outs_taps, op.inplace_map, op.seqs_taps, op.outs_taps,
op.truncate_gradient, op.go_backwards, op.truncate_gradient, op.go_backwards,
store_steps, op.mode,op.n_fixed_steps, store_steps, op.mode,op.n_fixed_steps,
op.inplace).make_node(*node.inputs) op.inplace, name = op.fn.name).make_node(*node.inputs)
# we not need to replace the outputs of scan # we not need to replace the outputs of scan
for i,out in enumerate(node.outputs): for i,out in enumerate(node.outputs):
# if we are dealing with an output for which # if we are dealing with an output for which
...@@ -1364,7 +1399,7 @@ def scan_make_inplace(node): ...@@ -1364,7 +1399,7 @@ def scan_make_inplace(node):
return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs, return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps,
op.truncate_gradient, op.go_backwards, op.store_steps, op.mode, op.truncate_gradient, op.go_backwards, op.store_steps, op.mode,
op.n_fixed_steps, inplace=True ).make_node(*node.inputs).outputs op.n_fixed_steps, inplace=True, name = op.fn.name).make_node(*node.inputs).outputs
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论