提交 7e415c47 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

PEP 8

上级 81e1a1e9
...@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain ...@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools. amount of useful generic optimization tools.
""" """
import copy
import copy, logging, sys, time import logging
import sys
import time
import numpy import numpy
import graph import graph
from env import InconsistencyError from env import InconsistencyError
import op
import utils import utils
import unify import unify
import toolbox import toolbox
import op
import theano import theano
from theano import config from theano import config
from theano.gof.python25 import any, all, deque from theano.gof.python25 import any, all, deque
from theano.configparser import AddConfigVar, BoolParam, config from theano.configparser import AddConfigVar, BoolParam
#if sys.version_info[:2] >= (2,5): #if sys.version_info[:2] >= (2,5):
# from collections import defaultdict # from collections import defaultdict
...@@ -39,9 +41,11 @@ import traceback ...@@ -39,9 +41,11 @@ import traceback
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(env): def _list_of_nodes(env):
return list(graph.io_toposort(env.inputs, env.outputs)) return list(graph.io_toposort(env.inputs, env.outputs))
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
...@@ -91,26 +95,30 @@ class Optimizer(object): ...@@ -91,26 +95,30 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, print >> stream, "%s%s %s id=%i" % (
name, id(self)) (' ' * level), self.__class__.__name__, name, id(self))
class FromFunctionOptimizer(Optimizer): class FromFunctionOptimizer(Optimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, fn): def __init__(self, fn):
self.apply = fn self.apply = fn
def add_requirements(self, env): def add_requirements(self, env):
# Added by default # Added by default
#env.extend(toolbox.ReplaceValidate()) #env.extend(toolbox.ReplaceValidate())
pass pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, print >> stream, "%s%s id=%i" % (
' ' * level,
str(self.apply), str(self.apply),
id(self)) id(self))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs) return self.fn(*args, **kwargs)
def optimizer(f): def optimizer(f):
"""decorator for FromFunctionOptimizer""" """decorator for FromFunctionOptimizer"""
rval = FromFunctionOptimizer(f) rval = FromFunctionOptimizer(f)
...@@ -118,7 +126,6 @@ def optimizer(f): ...@@ -118,7 +126,6 @@ def optimizer(f):
return rval return rval
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__ #inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """WRITEME
...@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def warn(exc, self, optimizer): def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer """Default failure_callback for SeqOptimizer
""" """
_logger.error("SeqOptimizer apply %s"% str(optimizer)) _logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:") _logger.error("Traceback:")
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
if config.on_opt_error == 'raise': if config.on_opt_error == 'raise':
...@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list): ...@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME """WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l=[] l = []
nb_node_before = len(env.nodes) nb_node_before = len(env.nodes)
for optimizer in self: for optimizer in self:
try: try:
t0=time.time() t0 = time.time()
optimizer.optimize(env) optimizer.optimize(env)
l.append(float(time.time()-t0)) l.append(float(time.time() - t0))
except AssertionError: # do not catch Assertion failures except AssertionError:
# do not catch Assertion failures
raise raise
except Exception, e: except Exception, e:
if self.failure_callback: if self.failure_callback:
...@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list): ...@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation #added to override the list's __neq__ implementation
return id(self) != id(other) return id(self) != id(other)
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, name, id(self)) print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self))
# This way, -1 will do all depth # This way, -1 will do all depth
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for opt in self: for opt in self:
opt.print_summary(stream, level=level+2, depth=depth) opt.print_summary(stream, level=(level + 2), depth=depth)
class _metadict: class _metadict:
...@@ -219,13 +225,15 @@ class _metadict: ...@@ -219,13 +225,15 @@ class _metadict:
def __init__(self): def __init__(self):
self.d = {} self.d = {}
self.l = [] self.l = []
def __getitem__(self, item): def __getitem__(self, item):
return self.get(item, None) return self.get(item, None)
def __setitem__(self, item, value): def __setitem__(self, item, value):
try: try:
self.d[item] = value self.d[item] = value
except Exception: except Exception:
for i, (key,val) in enumerate(self.l): for i, (key, val) in enumerate(self.l):
if key == item: if key == item:
self.l[i] = (item, value) self.l[i] = (item, value)
return return
...@@ -265,9 +273,11 @@ class _metadict: ...@@ -265,9 +273,11 @@ class _metadict:
return value return value
else: else:
return default return default
def clear(self): def clear(self):
self.d = {} self.d = {}
self.l = [] self.l = []
def __str__(self): def __str__(self):
return "(%s, %s)" % (self.d, self.l) return "(%s, %s)" % (self.d, self.l)
...@@ -528,12 +538,13 @@ def pre_constant_merge(vars): ...@@ -528,12 +538,13 @@ def pre_constant_merge(vars):
const_sig_inv[sig] = var const_sig_inv[sig] = var
return var return var
if var.owner: if var.owner:
for idx,inp in enumerate(var.owner.inputs): for idx, inp in enumerate(var.owner.inputs):
var.owner.inputs[idx] = recursive_merge(inp) var.owner.inputs[idx] = recursive_merge(inp)
return var return var
return map(recursive_merge, vars) return map(recursive_merge, vars)
######################## ########################
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
...@@ -557,25 +568,31 @@ class LocalOptimizer(object): ...@@ -557,25 +568,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two Subclasses should implement this function so that it returns one of two
kinds of things: kinds of things:
- False to indicate that no optimization can be applied to this `node`; or - False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the greater graph. - <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance :type node: an Apply instance
""" """
raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__) raise utils.MethodNotDefined("transform",
type(self), self.__class__.__name__)
def add_requirements(self, env): def add_requirements(self, env):
"""If this local optimization wants to add some requirements to the env, """
This is the place to do it.""" If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default # Added by default
#env.extend(toolbox.ReplaceValidate()) #env.extend(toolbox.ReplaceValidate())
pass pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self))
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """WRITEME"""
...@@ -584,15 +601,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -584,15 +601,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks = [] tracks = []
self.transform = fn self.transform = fn
self._tracks = tracks self._tracks = tracks
def tracks(self): def tracks(self):
return self._tracks return self._tracks
def __str__(self): def __str__(self):
return getattr(self, '__name__', '<FromFunctionLocalOptimizer instance>') return getattr(self, '__name__',
'<FromFunctionLocalOptimizer instance>')
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, print >> stream, "%s%s id=%i" % (
' ' * level,
str(self.transform), str(self.transform),
id(self)) id(self))
def local_optimizer(*tracks): def local_optimizer(*tracks):
def decorator(f): def decorator(f):
"""WRITEME""" """WRITEME"""
...@@ -607,11 +630,15 @@ class LocalOptGroup(LocalOptimizer): ...@@ -607,11 +630,15 @@ class LocalOptGroup(LocalOptimizer):
def __init__(self, *optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers) self.reentrant = any(getattr(opt, 'reentrant', True)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers)
def __str__(self): def __str__(self):
return getattr(self, '__name__', '<theano.gof.opt.LocalOptGroup instance>'+str([str(o) for o in self.opts])) return getattr(self, '__name__',
('<theano.gof.opt.LocalOptGroup instance>'
+ str([str(o) for o in self.opts])))
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
...@@ -620,11 +647,12 @@ class LocalOptGroup(LocalOptimizer): ...@@ -620,11 +647,12 @@ class LocalOptGroup(LocalOptimizer):
return repl return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self))
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for lopt in self.opts: for lopt in self.opts:
lopt.print_summary(stream, level=level+2, depth=depth) lopt.print_summary(stream, level=(level + 2), depth=depth)
class _LocalOpKeyOptGroup(LocalOptGroup): class _LocalOpKeyOptGroup(LocalOptGroup):
...@@ -644,13 +672,16 @@ class OpSub(LocalOptimizer): ...@@ -644,13 +672,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing. another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
""" """
reentrant = False # an OpSub does not apply to the nodes it produces # an OpSub does not apply to the nodes it produces
retains_inputs = True # all the inputs of the original node are transferred to the outputs reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
def __init__(self, op1, op2, transfer_tags = True): def __init__(self, op1, op2, transfer_tags=True):
""" """
op1.make_node and op2.make_node must take the same number of op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs. inputs and have the same number of outputs.
...@@ -705,7 +736,8 @@ class OpRemove(LocalOptimizer): ...@@ -705,7 +736,8 @@ class OpRemove(LocalOptimizer):
return "%s(x) -> x" % (self.op) return "%s(x) -> x" % (self.op)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s(%s) id=%i" %(' '*level, print >> stream, "%s%s(%s) id=%i" % (
' ' * level,
self.__class__.__name__, self.__class__.__name__,
str(self.op), str(self.op),
id(self)) id(self))
...@@ -756,12 +788,12 @@ class PatternSub(LocalOptimizer): ...@@ -756,12 +788,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x', PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}), 'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn = None, name = None, pdb = False): skip_identities_fn=None, name=None, pdb=False):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
...@@ -771,7 +803,8 @@ class PatternSub(LocalOptimizer): ...@@ -771,7 +803,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail :param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than if one of the subpatterns has more than
one client. one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match. :param pdb: if True, we invoke pdb when the first node in the
pattern match.
""" """
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
...@@ -780,8 +813,11 @@ class PatternSub(LocalOptimizer): ...@@ -780,8 +813,11 @@ class PatternSub(LocalOptimizer):
elif isinstance(in_pattern, dict): elif isinstance(in_pattern, dict):
self.op = self.in_pattern['pattern'][0] self.op = self.in_pattern['pattern'][0]
else: else:
raise TypeError("The pattern to search for must start with a specific Op instance.") raise TypeError("The pattern to search for must start with "
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" "a specific Op instance.")
self.__doc__ = (self.__class__.__doc__
+ "\n\nThis instance does: "
+ str(self) + "\n")
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
...@@ -816,7 +852,7 @@ class PatternSub(LocalOptimizer): ...@@ -816,7 +852,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, allow_multiple_clients = False, pdb = False): def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
def retry_with_equiv(): def retry_with_equiv():
expr_equiv = self.skip_identities(expr) expr_equiv = self.skip_identities(expr)
if expr_equiv is None: if expr_equiv is None:
...@@ -829,7 +865,9 @@ class PatternSub(LocalOptimizer): ...@@ -829,7 +865,9 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1): if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
return retry_with_equiv() return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv() return retry_with_equiv()
...@@ -841,10 +879,14 @@ class PatternSub(LocalOptimizer): ...@@ -841,10 +879,14 @@ class PatternSub(LocalOptimizer):
try: try:
real_pattern = pattern['pattern'] real_pattern = pattern['pattern']
except KeyError: except KeyError:
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern) raise KeyError(
"Malformed pattern: %s (expected key 'pattern')"
% pattern)
constraint = pattern.get('constraint', lambda expr: True) constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', allow_multiple_clients)) return match(real_pattern, expr, u,
pattern.get('allow_multiple_clients',
allow_multiple_clients))
else: else:
return retry_with_equiv() return retry_with_equiv()
elif isinstance(pattern, basestring): elif isinstance(pattern, basestring):
...@@ -853,17 +895,22 @@ class PatternSub(LocalOptimizer): ...@@ -853,17 +895,22 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv() return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, graph.Constant): elif (isinstance(pattern, (int, float))
if numpy.all(theano.tensor.constant(pattern).value==expr.value): and isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr): elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
if pdb: if pdb:
import pdb;pdb.set_trace() import pdb
pdb.set_trace()
return u return u
def build(pattern, u): def build(pattern, u):
...@@ -872,11 +919,12 @@ class PatternSub(LocalOptimizer): ...@@ -872,11 +919,12 @@ class PatternSub(LocalOptimizer):
return pattern[0](*args) return pattern[0](*args)
elif isinstance(pattern, basestring): elif isinstance(pattern, basestring):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
elif isinstance(pattern, (int,float)): elif isinstance(pattern, (int, float)):
return pattern return pattern
else: else:
return pattern.clone() return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb) u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb)
if u: if u:
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
...@@ -886,23 +934,31 @@ class PatternSub(LocalOptimizer): ...@@ -886,23 +934,31 @@ class PatternSub(LocalOptimizer):
return False return False
def __str__(self): def __str__(self):
if getattr(self,'__name__',None): if getattr(self, '__name__', None):
return self.__name__ return self.__name__
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]])) return "%s(%s)" % (
str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions'))) return "%s subject to %s" % (
pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (
pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None)) name = getattr(self, '__name__', getattr(self, 'name', None))
print >> stream, "%s%s %s(%s, %s) id=%i" %(' '*level, print >> stream, "%s%s %s(%s, %s) id=%i" % (
' ' * level,
self.__class__.__name__, self.__class__.__name__,
name, name,
str(self.in_pattern), str(self.in_pattern),
...@@ -930,37 +986,48 @@ class NavigatorOptimizer(Optimizer): ...@@ -930,37 +986,48 @@ class NavigatorOptimizer(Optimizer):
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
if isinstance(exc, AssertionError) or config.on_opt_error == 'raise': if isinstance(exc, AssertionError) or config.on_opt_error == 'raise':
raise exc raise exc
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt): def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback """failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt) return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt)
@staticmethod @staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt): def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors """failure_callback for NavigatorOptimizer: ignore all errors
""" """
pass pass
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None):
""" """
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too). :param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees: :param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization - True: new subgraphs returned by an optimization is not a
- False: new subgraphs returned by an optimization is a candidate for optimization candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute. - False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback: :param failure_callback:
a function that takes (exception, navigator, [(old, new), a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception. (old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'. If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by example) then the new variables will be the ones created by
transform(). transform().
If this parameter is None, then exceptions are not caught here (raised normally). If this parameter is None, then exceptions are not caught here
(raised normally).
""" """
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
...@@ -969,15 +1036,19 @@ class NavigatorOptimizer(Optimizer): ...@@ -969,15 +1036,19 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None): def attach_updater(self, env, importer, pruner, chin=None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality. """
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph. :param importer: function that will be called whenever when
:param pruner: function to be called when optimizations remove stuff from graph. optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change. :param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later! :returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
""" """
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
...@@ -1010,21 +1081,22 @@ class NavigatorOptimizer(Optimizer): ...@@ -1010,21 +1081,22 @@ class NavigatorOptimizer(Optimizer):
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
def process_node(self, env, node, lopt = None): def process_node(self, env, node, lopt=None):
""" """
This function will use `lopt` to `transform` the `node`. The `transform` method will This function will use `lopt` to `transform` the `node`. The
return either False or a list of Variables that are intended to replace `node.outputs`. `transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this If the env accepts the replacement, then the optimization is
function returns True. successful, and this function returns True.
If there are no replacement candidates or the env rejects the replacements, this If there are no replacement candidates or the env rejects the
function returns False. replacements, this function returns False.
:param env: an Env :param env: an Env
:param node: an Apply instance in `env` :param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute :param lopt: a LocalOptimizer instance that may have a better idea for
node's outputs. how to compute node's outputs.
:rtype: Bool :rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`. :returns: True iff the `node`'s outputs were replaced in the `env`.
...@@ -1034,16 +1106,19 @@ class NavigatorOptimizer(Optimizer): ...@@ -1034,16 +1106,19 @@ class NavigatorOptimizer(Optimizer):
replacements = lopt.transform(node) replacements = lopt.transform(node)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs], lopt) self.failure_callback(e, self,
[(x, None) for x in node.outputs], lopt)
return False return False
else: else:
raise raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
if not isinstance(replacements, (tuple, list)): if not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. Expected list or tuple.' % lopt) raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements): if len(node.outputs) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' % lopt) raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt)
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
# the replacement # the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements) repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
...@@ -1056,8 +1131,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -1056,8 +1131,8 @@ class NavigatorOptimizer(Optimizer):
except Exception, e: except Exception, e:
# This means the replacements were rejected by the env. # This means the replacements were rejected by the env.
# #
# This is not supposed to happen. The default failure_callback will print a # This is not supposed to happen. The default failure_callback
# traceback as a warning. # will print a traceback as a warning.
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs, lopt) self.failure_callback(e, self, repl_pairs, lopt)
return False return False
...@@ -1072,26 +1147,33 @@ class NavigatorOptimizer(Optimizer): ...@@ -1072,26 +1147,33 @@ class NavigatorOptimizer(Optimizer):
self.local_opt.add_requirements(env) self.local_opt.add_requirements(env)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s (%i)" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s (%i)" % (
(' ' * level), self.__class__.__name__, id(self))
if depth != 0: if depth != 0:
self.local_opt.print_summary(stream, level=level+2, depth=depth-1) self.local_opt.print_summary(stream, level=(level + 2),
depth=(depth - 1))
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, order = 'in_to_out', ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None):
if order not in ['out_to_in', 'in_to_out']: if order not in ['out_to_in', 'in_to_out']:
raise ValueError("order must be 'out_to_in' or 'in_to_out'") raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback)
def apply(self, env, start_from = None): def apply(self, env, start_from=None):
if start_from is None: start_from = env.outputs if start_from is None:
start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from)) q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: try:
...@@ -1114,14 +1196,16 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1114,14 +1196,16 @@ class TopoOptimizer(NavigatorOptimizer):
self.detach_updater(env, u) self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None):
if not hasattr(local_opt, 'op_key'): if not hasattr(local_opt, 'op_key'):
raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.") raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) "an 'op_key' method.")
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback)
def apply(self, env): def apply(self, env):
op = self.local_opt.op_key() op = self.local_opt.op_key()
...@@ -1129,9 +1213,12 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1129,9 +1213,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q = reduce(list.__iadd__, map(env.get_nodes, op)) q = reduce(list.__iadd__, map(env.get_nodes, op))
else: else:
q = list(env.get_nodes(op)) q = list(env.get_nodes(op))
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
if node.op == op: q.append(node) if node.op == op:
q.append(node)
def pruner(node): def pruner(node):
if node is not current_node and node.op == op: if node is not current_node and node.op == op:
try: try:
...@@ -1159,7 +1246,6 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1159,7 +1246,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
class ChangeTracker: class ChangeTracker:
def __init__(self): def __init__(self):
self.changed = False self.changed = False
...@@ -1176,17 +1262,19 @@ class ChangeTracker: ...@@ -1176,17 +1262,19 @@ class ChangeTracker:
def on_attach(self, env): def on_attach(self, env):
env.change_tracker = self env.change_tracker = self
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self, def __init__(self,
optimizers, optimizers,
failure_callback = None, failure_callback=None,
max_depth = None, max_depth=None,
max_use_ratio = None): max_use_ratio=None):
""" """
:param optimizers: list or set of local or global optimizations to apply until :param optimizers: list or set of local or global optimizations to
equilibrium. apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number) :param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5) :param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...@@ -1194,8 +1282,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1194,8 +1282,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super(EquilibriumOptimizer, self).__init__( super(EquilibriumOptimizer, self).__init__(
None, None,
ignore_newtrees = True, ignore_newtrees=True,
failure_callback = failure_callback) failure_callback=failure_callback)
self.local_optimizers = [] self.local_optimizers = []
self.global_optimizers = [] self.global_optimizers = []
...@@ -1206,7 +1294,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1206,7 +1294,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
self.max_depth = max_depth self.max_depth = max_depth
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, 'max_use_ratio has to be a number' assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number')
def add_requirements(self, env): def add_requirements(self, env):
super(EquilibriumOptimizer, self).add_requirements(env) super(EquilibriumOptimizer, self).add_requirements(env)
...@@ -1216,7 +1305,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1216,7 +1305,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for opt in self.global_optimizers: for opt in self.global_optimizers:
opt.add_requirements(env) opt.add_requirements(env)
def apply(self, env, start_from = None): def apply(self, env, start_from=None):
if start_from is None: if start_from is None:
start_from = env.outputs start_from = env.outputs
changed = True changed = True
...@@ -1251,9 +1340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1251,9 +1340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes.append(len(q)) nb_nodes.append(len(q))
max_nb_nodes = max(max_nb_nodes, len(q)) max_nb_nodes = max(max_nb_nodes, len(q))
max_use = max_nb_nodes * self.max_use_ratio max_use = max_nb_nodes * self.max_use_ratio
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: try:
...@@ -1277,7 +1368,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1277,7 +1368,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name = (getattr(lopt, "name", None) opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", "")) or getattr(lopt, "__name__", ""))
if node not in env.nodes: if node not in env.nodes:
break # go to next node # go to next node
break
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
...@@ -1314,10 +1406,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1314,10 +1406,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, name, id(self)) print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self))
if depth != 0: if depth != 0:
for lopt in self.local_optimizers: for lopt in self.local_optimizers:
lopt.print_summary(stream, level=level+2, depth=depth-1) lopt.print_summary(stream, level=(level + 2),
depth=(depth - 1))
################# #################
...@@ -1340,7 +1434,8 @@ def _check_chain(r, chain): ...@@ -1340,7 +1434,8 @@ def _check_chain(r, chain):
return False return False
else: else:
try: try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem): if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
return False return False
except TypeError: except TypeError:
return False return False
...@@ -1354,6 +1449,7 @@ def _check_chain(r, chain): ...@@ -1354,6 +1449,7 @@ def _check_chain(r, chain):
return (r is not None) return (r is not None)
#_check_chain.n_calls = 0 #_check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME""" """WRITEME"""
if isinstance(r, graph.Apply): if isinstance(r, graph.Apply):
...@@ -1378,7 +1474,7 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1378,7 +1474,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can add additional node to the inputs of the node, it can
be needed to call this function multiple time. be needed to call this function multiple time.
''' '''
def local_recursive_function( list_opt, out, optimized_vars, depth): def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None): if not getattr(out, 'owner', None):
return [out], optimized_vars return [out], optimized_vars
node = out.owner node = out.owner
...@@ -1390,11 +1486,11 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1390,11 +1486,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else: else:
if inp.owner: if inp.owner:
outs, optimized_vars = local_recursive_function( outs, optimized_vars = local_recursive_function(
list_opt list_opt,
, inp inp,
, optimized_vars optimized_vars,
, depth+1) depth + 1)
for k,v in zip(inp.owner.outputs, outs): for k, v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)] nw_in = outs[inp.owner.outputs.index(inp)]
...@@ -1408,10 +1504,10 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1408,10 +1504,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret = opt.transform(node) ret = opt.transform(node)
if ret is not False and ret is not None: if ret is not False and ret is not None:
assert len(ret) == len(node.outputs) assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret): for k, v in zip(node.outputs, ret):
optimized_vars[k] = v optimized_vars[k] = v
results = ret results = ret
if ret[0].owner : if ret[0].owner:
node = out.owner node = out.owner
else: else:
break break
...@@ -1422,8 +1518,6 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1422,8 +1518,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return final_outs[0] return final_outs[0]
############ ############
### Misc ### ### Misc ###
############ ############
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论