Commit 46d46d76 authored by Iban Harlouchet

numpydoc for theano/gof/opt.py

Parent ae99f41d
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from __future__ import print_function
......@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
class Optimizer(object):
"""WRITEME
"""
WRITEME
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
"""
def __hash__(self):
......@@ -58,19 +62,25 @@ class Optimizer(object):
return id(self) != id(other)
def apply(self, fgraph):
    """
    Apply the optimization to the provided L{FunctionGraph}.

    It may use all the methods defined by the L{FunctionGraph}. If the
    L{Optimizer} needs to use a certain tool, such as an
    L{InstanceFinder}, it can do so in its L{add_requirements} method.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph to transform in place.
    """
    # Base class does nothing; subclasses override this method.
    pass
def optimize(self, fgraph, *args, **kwargs):
"""WRITEME
This is meant as a shortcut to::
"""
WRITEME
This is meant as a shortcut to:
opt.add_requirements(fgraph)
opt.apply(fgraph)
"""
self.add_requirements(fgraph)
try:
......@@ -82,18 +92,24 @@ class Optimizer(object):
return ret
def __call__(self, fgraph):
    """
    Optimize a L{FunctionGraph}.

    Same as ``self.optimize(fgraph)``.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph to optimize.
    """
    return self.optimize(fgraph)
def add_requirements(self, fgraph):
    """
    Add features to the fgraph that are required to apply the optimization.

    For example::

        fgraph.attach_feature(History())
        fgraph.attach_feature(MyFeature())

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph that will receive the required features.
    """
    # Base class requires nothing; subclasses attach features as needed.
    pass
......@@ -111,7 +127,10 @@ class Optimizer(object):
class FromFunctionOptimizer(Optimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, fn, requirements=()):
self.apply = fn
self.requirements = requirements
......@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
def optimizer(f):
    """
    Decorator that wraps a function into a FromFunctionOptimizer.

    The decorated function becomes the optimizer's ``apply`` method and
    the resulting optimizer keeps the function's name.
    """
    rval = FromFunctionOptimizer(f)
    rval.__name__ = f.__name__
    return rval
def inplace_optimizer(f):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
dh_handler = dh.DestroyHandler
requirements = (lambda fgraph:
fgraph.attach_feature(dh_handler()),)
......@@ -152,13 +177,18 @@ def inplace_optimizer(f):
class SeqOptimizer(Optimizer, list):
# inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""
WRITEME
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod
def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer
"""
Default failure_callback for SeqOptimizer.
"""
_logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:")
......@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
pdb.post_mortem(sys.exc_info()[2])
def __init__(self, *opts, **kw):
    """
    Build a SeqOptimizer from the given optimizers.

    Parameters
    ----------
    *opts
        The optimizers to apply in sequence. A single list or tuple of
        optimizers may be passed instead of separate arguments.
    failure_callback : callable, optional
        Keyword-only. Called as ``failure_callback(exception, self,
        optimizer)`` when one of the optimizers raises.
    """
    if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
        opts = opts[0]
    self[:] = opts
    self.failure_callback = kw.pop('failure_callback', None)
def apply(self, fgraph):
"""WRITEME
"""
WRITEME
Applies each L{Optimizer} in self in turn.
"""
l = []
if fgraph.profile:
......@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
def merge_profile(prof1, prof2):
"""
Merge 2 profiles returned by this cass apply() fct.
"""
new_t = []
new_l = []
......@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
class _metadict:
"""WRITEME"""
"""
WRITEME
"""
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
......@@ -430,6 +471,7 @@ class MergeFeature(object):
That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph.
"""
def on_attach(self, fgraph):
assert not hasattr(fgraph, 'merge_feature')
......@@ -493,7 +535,10 @@ class MergeFeature(object):
self.seen_constants.discard(id(c))
def process_constant(self, fgraph, c):
"""Check if a constant can be merged, and queue that replacement"""
"""
Check if a constant can be merged, and queue that replacement.
"""
if id(c) in self.seen_constants:
return
sig = c.merge_signature()
......@@ -511,7 +556,10 @@ class MergeFeature(object):
self.seen_constants.add(id(c))
def process_node(self, fgraph, node):
"""Check if a node can be merged, and queue that replacement."""
"""
Check if a node can be merged, and queue that replacement.
"""
if node in self.nodes_seen:
return
......@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def add_requirements(self, fgraph):
......@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
if givens is None:
givens = {}
......@@ -718,13 +768,15 @@ def pre_constant_merge(vars):
`vars` is a list of nodes, and we want to merge together nodes
that are constant inputs used to compute nodes in that list.
:note: This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
Notes
-----
This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
seen_var = set()
# signature -> variable (for constants)
const_sig_inv = {}
......@@ -767,10 +819,12 @@ def pre_constant_merge(vars):
########################
class LocalOptimizer(object):
"""A class for node-based optimizations.
"""
A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
"""
def __hash__(self):
......@@ -784,11 +838,13 @@ class LocalOptimizer(object):
Return the list of op classes that this opt applies to.
Return None to apply to all nodes.
"""
return None
def transform(self, node):
"""Transform a subgraph whose output is `node`.
"""
Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
......@@ -800,7 +856,9 @@ class LocalOptimizer(object):
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
:type node: an Apply instance
Parameters
----------
node : an Apply instance
"""
......@@ -810,8 +868,8 @@ class LocalOptimizer(object):
def add_requirements(self, fgraph):
"""
If this local optimization wants to add some requirements to the
fgraph,
This is the place to do it.
fgraph, this is the place to do it.
"""
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
......@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar(
class LocalMetaOptimizer(LocalOptimizer):
"""Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest"""
"""
Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest.
"""
def __init__(self, tracks=None, optimizers=()):
self._tracks = tracks
......@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer):
return
def provide_inputs(self, node, inputs):
    """
    If implemented, returns a dictionary mapping all symbolic variables
    in ``inputs`` to SharedVariable instances of suitable dummy values.
    The ``node`` can be inspected to infer required input shapes.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by a subclass.
    """
    raise NotImplementedError()
def time_call(self, fn):
......@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer):
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, fn, tracks=None, requirements=()):
self.transform = fn
self._tracks = tracks
......@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def local_optimizer(tracks, inplace=False):
def decorator(f):
"""WRITEME"""
"""
WRITEME
"""
if tracks is not None:
if len(tracks) is 0:
raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
......@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False):
class LocalOptGroup(LocalOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, *optimizers):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
......@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer):
class OpSub(LocalOptimizer):
"""WRITEME
"""
WRITEME
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that takes the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==>
Parameters
----------
op1, op2
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
Examples
--------
OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
# an OpSub does not apply to the nodes it produces
......@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer):
retains_inputs = True
def __init__(self, op1, op2, transfer_tags=True):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
"""
self.op1 = op1
self.op2 = op2
self.transfer_tags = transfer_tags
......@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer):
class OpRemove(LocalOptimizer):
"""WRITEME
"""
WRITEME
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
"""
reentrant = False # no nodes are added at all
......@@ -1085,25 +1168,27 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer):
"""WRITEME
"""
WRITEME
@todo update
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
......@@ -1123,45 +1208,51 @@ class PatternSub(LocalOptimizer):
trying to match and returns True or False according to an
arbitrary criterion.
Examples:
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
The constructor creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
Parameters
----------
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
allow_multiple_clients : bool
If False, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn : TODO
name
Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that self.tracks() will return. Useful to speed up
optimization sometimes.
get_nodes : optional
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this optimizer.
Notes
-----
`tracks` and `get_nodes` can be used to make this optimizer track a less
frequent Op, so this will make this optimizer tried less frequently.
Examples
--------
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
:param in_pattern: the input pattern that we want to replace
:param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
......@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
"""
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node):
......@@ -1357,12 +1449,40 @@ class Updater:
class NavigatorOptimizer(Optimizer):
"""Abstract class
"""
Abstract class.
Parameters
----------
local_opt
A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
ignore_newtrees
- True: new subgraphs returned by an optimization is not a
candidate for optimization.
- False: new subgraphs returned by an optimization is a candidate
for optimization.
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
failure_callback
A function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
@staticmethod
def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback
"""
Failure_callback for NavigatorOptimizer: print traceback.
"""
if config.on_opt_error != 'ignore':
_logger.error("Optimization failure due to: %s" % str(local_opt))
......@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer
"""
Failure_callback for NavigatorOptimizer.
Ignore InconsistencyErrors, print traceback.
ignore InconsistencyErrors, print traceback
"""
if isinstance(exc, InconsistencyError):
return
......@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt):
    """
    Failure_callback for NavigatorOptimizer: ignore all errors.
    """
    # Deliberately swallow every optimization failure.
    pass
def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None):
"""
:param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self.local_opt = local_opt
if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
......@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer):
Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The FunctionGraph plugin that handles the three tasks.
Parameters
----------
importer
Function that will be called whenever optimizations add stuff
to the graph.
pruner
Function to be called when optimizations remove stuff
from the graph.
chin
"on change input" called whenever a node's inputs change.
Returns
-------
object
The FunctionGraph plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if self.ignore_newtrees:
importer = None
......@@ -1449,18 +1558,25 @@ class NavigatorOptimizer(Optimizer):
return u
def detach_updater(self, fgraph, u):
    """
    Undo the work of attach_updater.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph the updater was attached to.
    u
        A return value of attach_updater. If None, nothing is done.

    Returns
    -------
    None
    """
    if u is not None:
        fgraph.remove_feature(u)
def process_node(self, fgraph, node, lopt=None):
"""
This function will use `lopt` to `transform` the `node`. The
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
......@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer):
If there are no replacement candidates or the fgraph rejects the
replacements, this function returns False.
:param fgraph: a FunctionGraph
:param node: an Apply instance in `fgraph`
:param lopt: a LocalOptimizer instance that may have a better idea for
Parameters
----------
fgraph
A FunctionGraph.
node
An Apply instance in `fgraph`
lopt
A LocalOptimizer instance that may have a better idea for
how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `fgraph`.
Returns
-------
bool
True iff the `node`'s outputs were replaced in the `fgraph`.
"""
lopt = lopt or self.local_opt
......@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer):
class TopoOptimizer(NavigatorOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None):
......@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer):
class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None):
......@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
Requires the following features:
- NodeFinder
- ReplaceValidate(Added by default)
"""
super(OpKeyOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(toolbox.NodeFinder())
......@@ -1686,24 +1817,27 @@ class ChangeTracker:
class EquilibriumOptimizer(NavigatorOptimizer):
"""
Apply optimizations until equilibrium point.
Parameters
----------
optimizers
List or set of local or global optimizations to apply until equilibrium.
max_use_ratio
Each optimizer can be applied at most (size of graph * this number)
times.
ignore_newtrees
See EquilibriumDB ignore_newtrees parameter definition.
"""
def __init__(self,
optimizers,
failure_callback=None,
ignore_newtrees=True,
max_use_ratio=None,
final_optimizers=None):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param ignore_newtrees: See EquilibriumDB ignore_newtrees
parameter definition
"""
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees=ignore_newtrees,
......@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain):
"""WRITEME"""
"""
WRITEME
"""
chain = list(reversed(chain))
while chain:
elem = chain.pop()
......@@ -2115,17 +2251,20 @@ def _check_chain(r, chain):
def check_chain(r, *chain):
    """
    Check whether the graph rooted at `r` matches the given op chain.

    If `r` is an Apply node, its first output is used as the root.
    Delegates to ``_check_chain`` after interleaving each chain element
    with output index 0.
    """
    if isinstance(r, graph.Apply):
        r = r.outputs[0]
    return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_local_optimizer(list_optimizations, out):
'''
"""
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the
fgraph. it applies each of the local_optimizations on the traversed graph.
fgraph. It applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
......@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
We should not apply optimizations on node that are in fgraph.
So we don't optimize node that have an attribute fgraph.
:note: This don't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
Notes
-----
This doesn't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
adds additional node to the inputs of the node, it can
be needed to call this function multiple times.
"""
def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None):
return [out], optimized_vars
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment