提交 46d46d76 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/opt.py

上级 ae99f41d
""" """
Defines the base class for optimizations as well as a certain Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools. amount of useful generic optimization tools.
""" """
from __future__ import print_function from __future__ import print_function
...@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph): ...@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
class Optimizer(object): class Optimizer(object):
"""WRITEME """
WRITEME
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it. An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}. of transformation you could apply to an L{FunctionGraph}.
""" """
def __hash__(self): def __hash__(self):
...@@ -58,19 +62,25 @@ class Optimizer(object): ...@@ -58,19 +62,25 @@ class Optimizer(object):
return id(self) != id(other) return id(self) != id(other)
def apply(self, fgraph): def apply(self, fgraph):
"""WRITEME """
WRITEME
Applies the optimization to the provided L{FunctionGraph}. It may Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an L{Optimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method. L{InstanceFinder}, it can do so in its L{add_requirements} method.
""" """
pass pass
def optimize(self, fgraph, *args, **kwargs): def optimize(self, fgraph, *args, **kwargs):
"""WRITEME """
This is meant as a shortcut to:: WRITEME
This is meant as a shortcut to:
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
opt.apply(fgraph) opt.apply(fgraph)
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
try: try:
...@@ -82,18 +92,24 @@ class Optimizer(object): ...@@ -82,18 +92,24 @@ class Optimizer(object):
return ret return ret
def __call__(self, fgraph): def __call__(self, fgraph):
"""WRITEME """
Same as self.optimize(fgraph) WRITEME
Same as self.optimize(fgraph).
""" """
return self.optimize(fgraph) return self.optimize(fgraph)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
"""WRITEME """
WRITEME
Add features to the fgraph that are required to apply the optimization. Add features to the fgraph that are required to apply the optimization.
For example: For example:
fgraph.attach_feature(History()) fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature()) fgraph.attach_feature(MyFeature())
etc. etc.
""" """
pass pass
...@@ -111,7 +127,10 @@ class Optimizer(object): ...@@ -111,7 +127,10 @@ class Optimizer(object):
class FromFunctionOptimizer(Optimizer): class FromFunctionOptimizer(Optimizer):
"""WRITEME""" """
WRITEME
"""
def __init__(self, fn, requirements=()): def __init__(self, fn, requirements=()):
self.apply = fn self.apply = fn
self.requirements = requirements self.requirements = requirements
...@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer): ...@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
def optimizer(f): def optimizer(f):
"""decorator for FromFunctionOptimizer""" """
Decorator for FromFunctionOptimizer.
"""
rval = FromFunctionOptimizer(f) rval = FromFunctionOptimizer(f)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
def inplace_optimizer(f): def inplace_optimizer(f):
"""decorator for FromFunctionOptimizer""" """
Decorator for FromFunctionOptimizer.
"""
dh_handler = dh.DestroyHandler dh_handler = dh.DestroyHandler
requirements = (lambda fgraph: requirements = (lambda fgraph:
fgraph.attach_feature(dh_handler()),) fgraph.attach_feature(dh_handler()),)
...@@ -152,13 +177,18 @@ def inplace_optimizer(f): ...@@ -152,13 +177,18 @@ def inplace_optimizer(f):
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
# inherit from Optimizer first to get Optimizer.__hash__ # inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """
WRITEME
Takes a list of L{Optimizer} instances and applies them Takes a list of L{Optimizer} instances and applies them
sequentially. sequentially.
""" """
@staticmethod @staticmethod
def warn(exc, self, optimizer): def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer """
Default failure_callback for SeqOptimizer.
""" """
_logger.error("SeqOptimizer apply %s" % str(optimizer)) _logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:") _logger.error("Traceback:")
...@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list): ...@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
pdb.post_mortem(sys.exc_info()[2]) pdb.post_mortem(sys.exc_info()[2])
def __init__(self, *opts, **kw): def __init__(self, *opts, **kw):
"""WRITEME""" """
WRITEME
"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)): if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0] opts = opts[0]
self[:] = opts self[:] = opts
self.failure_callback = kw.pop('failure_callback', None) self.failure_callback = kw.pop('failure_callback', None)
def apply(self, fgraph): def apply(self, fgraph):
"""WRITEME """
WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l = [] l = []
if fgraph.profile: if fgraph.profile:
...@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
""" """
Merge 2 profiles returned by this class' apply() fct. Merge 2 profiles returned by this class' apply() fct.
""" """
new_t = [] new_t = []
new_l = [] new_l = []
...@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list): ...@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
class _metadict: class _metadict:
"""WRITEME""" """
WRITEME
"""
# dict that accepts unhashable keys # dict that accepts unhashable keys
# uses an associative list # uses an associative list
# for internal use only # for internal use only
...@@ -430,6 +471,7 @@ class MergeFeature(object): ...@@ -430,6 +471,7 @@ class MergeFeature(object):
That way, the MergeOptimizer can remember the result of the last merge That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph. pass on the fgraph.
""" """
def on_attach(self, fgraph): def on_attach(self, fgraph):
assert not hasattr(fgraph, 'merge_feature') assert not hasattr(fgraph, 'merge_feature')
...@@ -493,7 +535,10 @@ class MergeFeature(object): ...@@ -493,7 +535,10 @@ class MergeFeature(object):
self.seen_constants.discard(id(c)) self.seen_constants.discard(id(c))
def process_constant(self, fgraph, c): def process_constant(self, fgraph, c):
"""Check if a constant can be merged, and queue that replacement""" """
Check if a constant can be merged, and queue that replacement.
"""
if id(c) in self.seen_constants: if id(c) in self.seen_constants:
return return
sig = c.merge_signature() sig = c.merge_signature()
...@@ -511,7 +556,10 @@ class MergeFeature(object): ...@@ -511,7 +556,10 @@ class MergeFeature(object):
self.seen_constants.add(id(c)) self.seen_constants.add(id(c))
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
"""Check if a node can be merged, and queue that replacement.""" """
Check if a node can be merged, and queue that replacement.
"""
if node in self.nodes_seen: if node in self.nodes_seen:
return return
...@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer): ...@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1). int(1) for example, are transferred to a particular instance of int(1).
""" """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
...@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
Merge-based implementation of `theano.gof.graph.is_same_graph`. Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation. See help on `theano.gof.graph.is_same_graph` for additional documentation.
""" """
if givens is None: if givens is None:
givens = {} givens = {}
...@@ -718,13 +768,15 @@ def pre_constant_merge(vars): ...@@ -718,13 +768,15 @@ def pre_constant_merge(vars):
`vars` is a list of nodes, and we want to merge together nodes `vars` is a list of nodes, and we want to merge together nodes
that are constant inputs used to compute nodes in that list. that are constant inputs used to compute nodes in that list.
:note: This function will ignore nodes that are in an fgraph. Notes
It is used to pre-merge nodes generated inside an optimization, -----
before it is inserted in the fgraph. This function will ignore nodes that are in an fgraph.
It is useful if there are many such replacements to make, It is used to pre-merge nodes generated inside an optimization,
so that DebugMode will not check each of them. before it is inserted in the fgraph.
""" It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
seen_var = set() seen_var = set()
# signature -> variable (for constants) # signature -> variable (for constants)
const_sig_inv = {} const_sig_inv = {}
...@@ -767,10 +819,12 @@ def pre_constant_merge(vars): ...@@ -767,10 +819,12 @@ def pre_constant_merge(vars):
######################## ########################
class LocalOptimizer(object): class LocalOptimizer(object):
"""A class for node-based optimizations. """
A class for node-based optimizations.
Instances should implement the transform function, Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance. and be passed to configure a fgraph-based Optimizer instance.
""" """
def __hash__(self): def __hash__(self):
...@@ -784,11 +838,13 @@ class LocalOptimizer(object): ...@@ -784,11 +838,13 @@ class LocalOptimizer(object):
Return the list of op classes that this opt applies to. Return the list of op classes that this opt applies to.
Return None to apply to all nodes. Return None to apply to all nodes.
""" """
return None return None
def transform(self, node): def transform(self, node):
"""Transform a subgraph whose output is `node`. """
Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two Subclasses should implement this function so that it returns one of two
kinds of things: kinds of things:
...@@ -800,7 +856,9 @@ class LocalOptimizer(object): ...@@ -800,7 +856,9 @@ class LocalOptimizer(object):
- dict(old variables -> new variables). A dictionary that map - dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace. from old variables to new variables to replace.
:type node: an Apply instance Parameters
----------
node : an Apply instance
""" """
...@@ -810,8 +868,8 @@ class LocalOptimizer(object): ...@@ -810,8 +868,8 @@ class LocalOptimizer(object):
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
""" """
If this local optimization wants to add some requirements to the If this local optimization wants to add some requirements to the
fgraph, fgraph, this is the place to do it.
This is the place to do it.
""" """
# Added by default # Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate()) # fgraph.attach_feature(toolbox.ReplaceValidate())
...@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar( ...@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar(
class LocalMetaOptimizer(LocalOptimizer): class LocalMetaOptimizer(LocalOptimizer):
"""Base class for meta-optimizers that try a set of LocalOptimizers """
to replace a node and choose the one that executes the fastest""" Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest.
"""
def __init__(self, tracks=None, optimizers=()): def __init__(self, tracks=None, optimizers=()):
self._tracks = tracks self._tracks = tracks
...@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer):
return return
def provide_inputs(self, node, inputs): def provide_inputs(self, node, inputs):
"""If implemented, returns a dictionary mapping all symbolic variables """
in ``inputs`` to SharedVariable instances of suitable dummy values. The If implemented, returns a dictionary mapping all symbolic variables
``node`` can be inspected to infer required input shapes.""" in ``inputs`` to SharedVariable instances of suitable dummy values.
The ``node`` can be inspected to infer required input shapes.
"""
raise NotImplementedError() raise NotImplementedError()
def time_call(self, fn): def time_call(self, fn):
...@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer):
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """
WRITEME
"""
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
self.transform = fn self.transform = fn
self._tracks = tracks self._tracks = tracks
...@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def local_optimizer(tracks, inplace=False): def local_optimizer(tracks, inplace=False):
def decorator(f): def decorator(f):
"""WRITEME""" """
WRITEME
"""
if tracks is not None: if tracks is not None:
if len(tracks) is 0: if len(tracks) is 0:
raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__) raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
...@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False): ...@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False):
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
"""WRITEME""" """
WRITEME
"""
def __init__(self, *optimizers): def __init__(self, *optimizers):
if len(optimizers) == 1 and isinstance(optimizers[0], list): if len(optimizers) == 1 and isinstance(optimizers[0], list):
...@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer):
class OpSub(LocalOptimizer): class OpSub(LocalOptimizer):
"""WRITEME """
WRITEME
Replaces the application of a certain op by the application of Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing. another op that takes the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> Parameters
----------
op1, op2
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
Examples
--------
OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
""" """
# an OpSub does not apply to the nodes it produces # an OpSub does not apply to the nodes it produces
...@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer): ...@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer):
retains_inputs = True retains_inputs = True
def __init__(self, op1, op2, transfer_tags=True): def __init__(self, op1, op2, transfer_tags=True):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
"""
self.op1 = op1 self.op1 = op1
self.op2 = op2 self.op2 = op2
self.transfer_tags = transfer_tags self.transfer_tags = transfer_tags
...@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer): ...@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer):
class OpRemove(LocalOptimizer): class OpRemove(LocalOptimizer):
"""WRITEME """
WRITEME
Removes all applications of an op by transferring each of its Removes all applications of an op by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
""" """
reentrant = False # no nodes are added at all reentrant = False # no nodes are added at all
...@@ -1085,25 +1168,27 @@ class OpRemove(LocalOptimizer): ...@@ -1085,25 +1168,27 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer): class PatternSub(LocalOptimizer):
"""WRITEME """
WRITEME
@todo update @todo update
Replaces all occurrences of the input pattern by the output pattern: Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...) input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>, input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>) constraint = <constraint>)
sub_pattern ::= input_pattern sub_pattern ::= input_pattern
sub_pattern ::= string sub_pattern ::= string
sub_pattern ::= a Constant instance sub_pattern ::= a Constant instance
sub_pattern ::= int sub_pattern ::= int
sub_pattern ::= float sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...) output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string output_pattern ::= string
output_pattern ::= int output_pattern ::= int
output_pattern ::= float output_pattern ::= float
Each string in the input pattern is a variable that will be set to Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is whatever expression is found in its place. If the same string is
...@@ -1123,45 +1208,51 @@ class PatternSub(LocalOptimizer): ...@@ -1123,45 +1208,51 @@ class PatternSub(LocalOptimizer):
trying to match and returns True or False according to an trying to match and returns True or False according to an
arbitrary criterion. arbitrary criterion.
Examples: The constructor creates a PatternSub that replaces occurrences of
PatternSub((add, 'x', 'y'), (add, 'y', 'x')) in_pattern by occurrences of out_pattern.
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') Parameters
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) ----------
PatternSub((boggle, {'pattern': 'x', in_pattern
'constraint': lambda expr: expr.type == scrabble}), The input pattern that we want to replace.
(scrabble, 'x')) out_pattern
The replacement pattern.
allow_multiple_clients : bool
If False, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn : TODO
name
Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that self.tracks() will return. Useful to speed up
optimization sometimes.
get_nodes : optional
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this optimizer.
Notes
-----
`tracks` and `get_nodes` can be used to make this optimizer track a less
frequent Op, so this optimizer will be tried less frequently.
Examples
--------
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False, skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None): tracks=(), get_nodes=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
:param in_pattern: the input pattern that we want to replace
:param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)): if isinstance(in_pattern, (list, tuple)):
...@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer): ...@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer):
""" """
Checks if the graph from node corresponds to in_pattern. If it does, Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement. constructs out_pattern and performs the replacement.
""" """
if get_nodes and self.get_nodes is not None: if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node): for real_node in self.get_nodes(node):
...@@ -1357,12 +1449,40 @@ class Updater: ...@@ -1357,12 +1449,40 @@ class Updater:
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""Abstract class """
Abstract class.
Parameters
----------
local_opt
A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
ignore_newtrees
- True: new subgraphs returned by an optimization is not a
candidate for optimization.
- False: new subgraphs returned by an optimization is a candidate
for optimization.
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
failure_callback
A function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
""" """
@staticmethod @staticmethod
def warn(exc, nav, repl_pairs, local_opt): def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback """
Failure_callback for NavigatorOptimizer: print traceback.
""" """
if config.on_opt_error != 'ignore': if config.on_opt_error != 'ignore':
_logger.error("Optimization failure due to: %s" % str(local_opt)) _logger.error("Optimization failure due to: %s" % str(local_opt))
...@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer): ...@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer):
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt): def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer """
Failure_callback for NavigatorOptimizer.
Ignore InconsistencyErrors, print traceback.
ignore InconsistencyErrors, print traceback
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
...@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer): ...@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer):
@staticmethod @staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt): def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors """
Failure_callback for NavigatorOptimizer: ignore all errors.
""" """
pass pass
def __init__(self, local_opt, ignore_newtrees='auto', def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None): failure_callback=None):
"""
:param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True) self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
...@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer): ...@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer):
Install some FunctionGraph listeners to help the navigator deal with Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality. the ignore_trees-related functionality.
:param importer: function that will be called whenever when Parameters
optimizations add stuff to the graph. ----------
:param pruner: function to be called when optimizations remove stuff importer
from graph. Function that will be called whenever optimizations add stuff
:param chin: "on change input" called whenever an node's inputs change. to the graph.
pruner
:returns: The FunctionGraph plugin that handles the three tasks. Function to be called when optimizations remove stuff
from the graph.
chin
"on change input" called whenever a node's inputs change.
Returns
-------
object
The FunctionGraph plugin that handles the three tasks.
Keep this around so that you can detach later! Keep this around so that you can detach later!
""" """
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
...@@ -1449,18 +1558,25 @@ class NavigatorOptimizer(Optimizer): ...@@ -1449,18 +1558,25 @@ class NavigatorOptimizer(Optimizer):
return u return u
def detach_updater(self, fgraph, u): def detach_updater(self, fgraph, u):
"""Undo the work of attach_updater. """
Undo the work of attach_updater.
Parameters
----------
u
A return-value of attach_updater.
:param u: a return-value of attach_updater Returns
-------
None
:returns: None.
""" """
if u is not None: if u is not None:
fgraph.remove_feature(u) fgraph.remove_feature(u)
def process_node(self, fgraph, node, lopt=None): def process_node(self, fgraph, node, lopt=None):
""" """
This function will use `lopt` to `transform` the `node`. The This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables `transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`. that are intended to replace `node.outputs`.
...@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer): ...@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer):
If there are no replacement candidates or the fgraph rejects the If there are no replacement candidates or the fgraph rejects the
replacements, this function returns False. replacements, this function returns False.
:param fgraph: a FunctionGraph Parameters
:param node: an Apply instance in `fgraph` ----------
:param lopt: a LocalOptimizer instance that may have a better idea for fgraph
A FunctionGraph.
node
An Apply instance in `fgraph`
lopt
A LocalOptimizer instance that may have a better idea for
how to compute node's outputs. how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `fgraph`. Returns
-------
bool
True iff the `node`'s outputs were replaced in the `fgraph`.
""" """
lopt = lopt or self.local_opt lopt = lopt or self.local_opt
...@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer):
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME""" """
WRITEME
"""
def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False, def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None): failure_callback=None):
...@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer):
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME""" """
WRITEME
"""
def __init__(self, local_opt, ignore_newtrees=False, def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None): failure_callback=None):
...@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
Requires the following features: Requires the following features:
- NodeFinder - NodeFinder
- ReplaceValidate(Added by default) - ReplaceValidate(Added by default)
""" """
super(OpKeyOptimizer, self).add_requirements(fgraph) super(OpKeyOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(toolbox.NodeFinder()) fgraph.attach_feature(toolbox.NodeFinder())
...@@ -1686,24 +1817,27 @@ class ChangeTracker: ...@@ -1686,24 +1817,27 @@ class ChangeTracker:
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
"""
Apply optimizations until equilibrium point.
Parameters
----------
optimizers
List or set of local or global optimizations to apply until equilibrium.
max_use_ratio
Each optimizer can be applied at most (size of graph * this number)
times.
ignore_newtrees
See EquilibriumDB ignore_newtrees parameter definition.
"""
def __init__(self, def __init__(self,
optimizers, optimizers,
failure_callback=None, failure_callback=None,
ignore_newtrees=True, ignore_newtrees=True,
max_use_ratio=None, max_use_ratio=None,
final_optimizers=None): final_optimizers=None):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param ignore_newtrees: See EquilibriumDB ignore_newtrees
parameter definition
"""
super(EquilibriumOptimizer, self).__init__( super(EquilibriumOptimizer, self).__init__(
None, None,
ignore_newtrees=ignore_newtrees, ignore_newtrees=ignore_newtrees,
...@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain): def _check_chain(r, chain):
"""WRITEME""" """
WRITEME
"""
chain = list(reversed(chain)) chain = list(reversed(chain))
while chain: while chain:
elem = chain.pop() elem = chain.pop()
...@@ -2115,17 +2251,20 @@ def _check_chain(r, chain): ...@@ -2115,17 +2251,20 @@ def _check_chain(r, chain):
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME""" """
WRITEME
"""
if isinstance(r, graph.Apply): if isinstance(r, graph.Apply):
r = r.outputs[0] r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_local_optimizer(list_optimizations, out): def pre_greedy_local_optimizer(list_optimizations, out):
''' """
This function traverses the computation graph described by all This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the ``node`` in the graph before the variable out but that are not in the
fgraph. it applies each of the local_optimizations on the traversed graph. fgraph. It applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor. the graph of the indices of a subtensor.
...@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
We should not apply optimizations on node that are in fgraph. We should not apply optimizations on node that are in fgraph.
So we don't optimize node that have an attribute fgraph. So we don't optimize node that have an attribute fgraph.
:note: This don't do an equilibrium... So if there is optimization Notes
like local_upcast_elemwise_constant_inputs in the list, that -----
add additional node to the inputs of the node, it can This doesn't do an equilibrium... So if there is optimization
be needed to call this function multiple time. like local_upcast_elemwise_constant_inputs in the list, that
''' adds additional node to the inputs of the node, it can
be needed to call this function multiple times.
"""
def local_recursive_function(list_opt, out, optimized_vars, depth): def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None): if not getattr(out, 'owner', None):
return [out], optimized_vars return [out], optimized_vars
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论