提交 d4dfbf2a authored 作者: nouiz's avatar nouiz

Merge pull request #612 from lamblin/merge_feature_rebased

Merge feature (rebased)
"""WRITEME
"""
import os, logging, warnings
import logging
import numpy, theano
import numpy
import theano
from theano import gof
import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam, EnumStr
from theano.configparser import config, AddConfigVar, StrParam
_logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding',
"When using the default mode, we will remove optimizer with that tag. Separate many tags with ':'.",
("When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
AddConfigVar('optimizer_including',
"When using the default mode, we will add optimizer with that tag. Separate many tags with ':'.",
("When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
AddConfigVar('optimizer_requiring',
"When using the default mode, we will require optimizer with that tag. Separate many tags with ':'.",
("When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
def check_equal(x, y):
"""
Returns True iff x[0] and y[0] are equal (checks the dtype and
......@@ -32,35 +38,37 @@ def check_equal(x, y):
import scipy.sparse as sp
x, y = x[0], y[0]
# TODO: bug in current scipy, two sparse matrices are never equal, remove when moving to 0.7
# TODO: bug in current scipy, two sparse matrices are never equal,
# remove when moving to 0.7
if sp.issparse(x):
x = x.todense()
if sp.issparse(y):
y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
if (x.dtype != y.dtype
or x.shape != y.shape
or numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
else:
if x != y:
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
# Each key is listed exactly once, so every linker below is constructed a
# single time when the module is imported.
predefined_linkers = {
    'py': gof.PerformLinker(),
    'c': gof.CLinker(),
    'c|py': gof.OpWiseCLinker(allow_gc=True),
    'c|py_nogc': gof.OpWiseCLinker(allow_gc=False),
    'c&py': gof.DualLinker(checker=check_equal),
    'vm': gof.vm.VM_Linker(allow_gc=True, use_cloop=False),
    'cvm': gof.vm.VM_Linker(allow_gc=True, use_cloop=True),
    'vm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
    'cvm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
    }
......@@ -72,37 +80,37 @@ def register_linker(name, linker):
predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
# Predefined optimizer queries.  Each one is built exactly once.
OPT_FAST_RUN = gof.Query(include=['fast_run'])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable')
OPT_FAST_COMPILE = gof.Query(include=['fast_compile'])
# Same tag as OPT_FAST_RUN, but with a position cutoff set just below.
OPT_STABILIZE = gof.Query(include=['fast_run'])
OPT_STABILIZE.position_cutoff = 1.5000001
# Maps an optimizer name to the object used for it.  Both None and the
# string 'None' map to a callable that takes an env and does nothing.
predefined_optimizers = {
    None: (lambda env: None),
    'None': (lambda env: None),
    'merge': gof.MergeOptimizer(),
    'fast_run': OPT_FAST_RUN,
    'fast_run_stable': OPT_FAST_RUN_STABLE,
    'fast_compile': OPT_FAST_COMPILE,
    'stabilize': OPT_STABILIZE
    }
def register_optimizer(name, opt):
    """Add a `Optimizer` which can be referred to by `name` in `Mode`.

    :param name: key under which `opt` is stored in `predefined_optimizers`.
    :param opt: the object to store.
    :raises ValueError: if `name` is already a key of
        `predefined_optimizers`; the existing entry is left untouched.
    """
    name_is_taken = name in predefined_optimizers
    if name_is_taken:
        raise ValueError('Optimizer name already taken: %s' % name)
    predefined_optimizers[name] = opt
def register_OutputGuard_c_code(type):
    """Append `type` to the `OutputGuard.c_code_types` list."""
    registered_types = OutputGuard.c_code_types
    registered_types.append(type)
class OutputGuard(gof.Op):
"""
This op is used only internally by Theano.
......@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
TODO: find a current full explanation.
"""
destroy_map = {0:[0]}
view_map = {0:[0]}
destroy_map = {0: [0]}
view_map = {0: [0]}
c_code_types = []
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, inp, out):
x, = inp
z, = out
z[0] = x
def __str__(self):
return '%s' % self.__class__.__name__
......@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
x, = inp
z, = out
if isinstance(node.inputs[0].type, theano.scalar.Scalar):
# Scalars are C objects on the stacks, and should not be inc/decrefed
# Scalars are C objects on the stack,
# and should not be inc/decrefed
return """
%(z)s = %(x)s;
""" % locals()
......@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
_output_guard = OutputGuard()
class AddDestroyHandler(gof.Optimizer):
    """This optimizer performs two important functions:

    1) it has a 'requirement' of the destroyhandler. This means that the env
    will include it as a feature for this optimization, and keep this feature
    enabled for subsequent optimizations.  All optimizations that work inplace
    on any of their inputs must run *after* this optimization to ensure that
    the DestroyHandler has been included in the env.

    2) It tries to replace each output with an Op that purports to destroy it
    (but it won't I promise).  If this replacement succeeds it means that
    there is a bug in theano.  It should not be possible to destroy outputs.
    """
    def apply(self, env):
        """Try to wrap every output of `env` in the output guard op.

        A successful replacement is logged at INFO level; an
        `gof.InconsistencyError` from the replacement is ignored.
        """
        for o in env.outputs:
            try:
                env.replace_validate(o, _output_guard(o),
                                     reason='output_guard')
            except gof.InconsistencyError:
                # This output is already impossible to destroy.
                # No guard necessary
                pass
            else:
                # The variable is passed as a logging argument so that the
                # message is only formatted when INFO records are emitted.
                _logger.info("Output variable %s required output_guard, "
                             "how was this output left unprotected against "
                             "destructive operations?",
                             o)

    def add_requirements(self, env):
        """Run the parent's requirements, then add a DestroyHandler to `env`."""
        super(AddDestroyHandler, self).add_requirements(env)
        env.extend(gof.DestroyHandler())
class PrintCurrentEnv(gof.Optimizer):
    """This optimizer is for debugging.

    Toss it into the optimization pipeline to see the state of things at any
    given point.
    """
    def __init__(self, header):
        # Text printed in front of the debug output, identifying this pass.
        self.header = header

    def apply(self, env):
        """Print the header, then the debugprint of the outputs of `env`."""
        # Imported here rather than at module level, as in the original.
        import theano.printing
        # Single-string form: prints the same text as
        # `print "PrintCurrentEnv:", self.header`.
        print("PrintCurrentEnv: %s" % (self.header,))
        theano.printing.debugprint(env.outputs)
# Global database of optimizations.  Each entry is registered exactly once
# with its name, its object, a numeric position and a list of tags.
optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(),
        0, 'fast_run', 'fast_compile')
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(),
        1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(),
        1.2, 'fast_run', 'fast_compile')
# The Print* entries are registered without tags
# ('fast_run' and 'fast_compile' are commented out).
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
        1.21)
# replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(),
        1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentEnv('Post-stabilize'),
        1.51)
# misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(),
        2, 'fast_run')
optdb.register('Print2.01', PrintCurrentEnv('Post-specialize'),
        2.01)
# misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(),
        3, 'fast_run')
# misc special cases for speed that are dependent on the device.
optdb.register('specialize_device', gof.EquilibriumDB(),
        48.6, 'fast_run')  # must be after gpu stuff at 48.5
# especially constant merge
optdb.register('merge2', gof.MergeOptimizer(),
        49, 'fast_run')
optdb.register('add_destroy_handler', AddDestroyHandler(),
        49.5, 'fast_run', 'inplace')
# final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(),
        100, 'fast_run')
......@@ -251,12 +292,15 @@ class Mode(object):
if optimizer is None:
optimizer = config.optimizer
self.__setstate__((linker, optimizer))
#self.provided_optimizer - typically the `optimizer` arg. But if the `optimizer` arg is
# keyword corresponding to a predefined Query, then this stores the query
#self._optimizer - typically same as provided_optimizer??
#self.__get_optimizer - returns self._optimizer (possibly querying optdb with self._optimizer)
#self.optimizer - property that returns __get_optimizer()
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# optdb with self._optimizer)
# self.optimizer - property that returns __get_optimizer()
def __getstate__(self):
    """Return the pair ``(provided_linker, provided_optimizer)``."""
    state = (self.provided_linker, self.provided_optimizer)
    return state
......@@ -275,12 +319,13 @@ class Mode(object):
self._optimizer = optimizer
self.call_time = 0
self.fn_time = 0
linker.mode = self #TODO: WHY IS THIS HERE?
linker.mode = self # TODO: WHY IS THIS HERE?
self.optimizer_time = 0
self.linker_time = 0
def __str__(self):
    """Describe this mode by the linker and optimizer it was given."""
    return "Mode(linker = %s, optimizer = %s)" % (
        self.provided_linker, self.provided_optimizer)
def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query):
......@@ -298,17 +343,20 @@ class Mode(object):
return (linker, optimizer)
def including(self, *tags):
    """Return a new instance of this class whose optimizer is
    ``opt.including(*tags)``.

    The linker and optimizer are obtained with a single call to
    `get_linker_optimizer` on the provided linker and optimizer.
    """
    link, opt = self.get_linker_optimizer(self.provided_linker,
                                          self.provided_optimizer)
    #N.B. opt might be a Query instance, not sure what else it might be...
    #     string? Optimizer? OptDB? who knows???
    return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags):
    """Return a new instance of this class whose optimizer is
    ``opt.excluding(*tags)``.

    The linker and optimizer are obtained with a single call to
    `get_linker_optimizer` on the provided linker and optimizer.
    """
    link, opt = self.get_linker_optimizer(self.provided_linker,
                                          self.provided_optimizer)
    return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags):
    """Return a new instance of this class whose optimizer is
    ``opt.requiring(*tags)``.

    The linker and optimizer are obtained with a single call to
    `get_linker_optimizer` on the provided linker and optimizer.
    """
    link, opt = self.get_linker_optimizer(self.provided_linker,
                                          self.provided_optimizer)
    return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or
......@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN': FAST_RUN,
}
instanciated_default_mode=None
instanciated_default_mode = None
def get_mode(orig_string):
if orig_string is None:
string = config.mode
else:
string = orig_string
if not isinstance(string, basestring):
return string #it is hopefully already a mode...
return string # it is hopefully already a mode...
global instanciated_default_mode
# The default mode is cached. However, config.mode can change
# If instanciated_default_mode has the right class, use it.
if orig_string is None and instanciated_default_mode:
if predefined_modes.has_key(string):
if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__
else:
default_mode_class = string
......@@ -342,7 +392,7 @@ def get_mode(orig_string):
default_mode_class):
return instanciated_default_mode
if string in ['Mode','ProfileMode','DebugMode']:
if string in ['Mode', 'ProfileMode', 'DebugMode']:
if string == 'DebugMode':
#need to import later to break circular dependency.
from debugmode import DebugMode
......@@ -350,12 +400,13 @@ def get_mode(orig_string):
ret = DebugMode(optimizer=config.optimizer)
else:
# The import is needed in case string is ProfileMode
from profilemode import ProfileMode,prof_mode_instance_to_print
ret = eval(string+'(linker=config.linker, optimizer=config.optimizer)')
elif predefined_modes.has_key(string):
from profilemode import ProfileMode, prof_mode_instance_to_print
ret = eval(string
+ '(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes:
ret = predefined_modes[string]
else:
raise Exception("No predefined mode exist for string: %s"%string)
raise Exception("No predefined mode exist for string: %s" % string)
if orig_string is None:
# Build and cache the default mode
......@@ -374,12 +425,14 @@ def get_mode(orig_string):
return ret
def get_default_mode():
    """Return ``get_mode(None)``, i.e. the mode named by ``config.mode``."""
    return get_mode(orig_string=None)
# Removed: use config.mode instead.
#default_mode = config.mode
def register_mode(name, mode):
"""Add a `Mode` which can be referred to by `name` in `function`."""
if name in predefined_modes:
......
......@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
import copy, logging, sys, time
import copy
import logging
import sys
import time
import numpy
import graph
from env import InconsistencyError
import op
import utils
import unify
import toolbox
import op
import theano
from theano import config
from theano.gof.python25 import any, all, deque
from theano.configparser import AddConfigVar, BoolParam, config
from theano.configparser import AddConfigVar, BoolParam
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
......@@ -39,9 +41,11 @@ import traceback
_optimizer_idx = [0]
def _list_of_nodes(env):
    """Return, as a list, the result of `graph.io_toposort` called on the
    inputs and outputs of `env`."""
    ordered = graph.io_toposort(env.inputs, env.outputs)
    return list(ordered)
class Optimizer(object):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
......@@ -91,26 +95,30 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print one line describing this optimizer to `stream`.

    The line holds `level` spaces of indentation, the class name, the
    `name` attribute (None when it is not set) and ``id(self)``.
    `depth` is accepted but not used here.
    """
    name = getattr(self, 'name', None)
    print >> stream, "%s%s %s id=%i" % (
        (' ' * level), self.__class__.__name__, name, id(self))
class FromFunctionOptimizer(Optimizer):
    """Optimizer whose `apply` method is a user-supplied function."""
    def __init__(self, fn):
        # The wrapped function is stored as the instance's `apply`.
        self.apply = fn

    def add_requirements(self, env):
        # Added by default
        #env.extend(toolbox.ReplaceValidate())
        pass

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        """Print one line with the wrapped function and ``id(self)``."""
        print >> stream, "%s%s id=%i" % (
            ' ' * level,
            str(self.apply),
            id(self))

    def __call__(self, *args, **kwargs):
        """Call the wrapped function directly with the given arguments."""
        # The function is stored in `self.apply`; there is no `self.fn`.
        return self.apply(*args, **kwargs)
def optimizer(f):
"""decorator for FromFunctionOptimizer"""
rval = FromFunctionOptimizer(f)
......@@ -118,7 +126,6 @@ def optimizer(f):
return rval
class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
......@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer
"""
_logger.error("SeqOptimizer apply %s"% str(optimizer))
_logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:")
_logger.error(traceback.format_exc())
if config.on_opt_error == 'raise':
......@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME
Applies each L{Optimizer} in self in turn.
"""
l=[]
l = []
nb_node_before = len(env.nodes)
for optimizer in self:
try:
t0=time.time()
t0 = time.time()
optimizer.optimize(env)
l.append(float(time.time()-t0))
except AssertionError: # do not catch Assertion failures
l.append(float(time.time() - t0))
except AssertionError:
# do not catch Assertion failures
raise
except Exception, e:
if self.failure_callback:
......@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation
return id(self) != id(other)
def __str__(self):
    """Return the list representation of self wrapped in ``SeqOpt(...)``."""
    contents = list.__str__(self)
    return "SeqOpt(%s)" % contents
......@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print a line for this optimizer, then the summary of each member.

    Members are printed two levels deeper.  `depth` is decremented at each
    nesting level and the recursion stops when it reaches 0.
    """
    name = getattr(self, 'name', None)
    print >> stream, "%s%s %s id=%i" % (
        (' ' * level), self.__class__.__name__, name, id(self))
    # This way, -1 will do all depth
    if depth != 0:
        depth -= 1
        for opt in self:
            opt.print_summary(stream, level=(level + 2), depth=depth)
class _metadict:
......@@ -219,17 +225,39 @@ class _metadict:
def __init__(self):
    # d: dict holding the entries whose key can be stored in a dict.
    # l: list of (key, value) pairs for the keys that cannot.
    self.d, self.l = {}, []
def __getitem__(self, item):
    """Return ``self.get(item, None)``."""
    return self.get(item, default=None)
def __setitem__(self, item, value):
    """Associate `value` with `item`.

    The pair goes into the dict `self.d` when that assignment works.  If
    it raises, the list `self.l` is used instead: an existing pair whose
    key equals `item` is overwritten, otherwise a new pair is appended.
    """
    try:
        self.d[item] = value
    except Exception:
        for i, (key, val) in enumerate(self.l):
            if key == item:
                self.l[i] = (item, value)
                return
        self.l.append((item, value))
def __delitem__(self, item):
    """Remove the entry stored under `item`.

    :raises KeyError: if `item` is neither a key of `self.d` nor the key
        of a pair in `self.l`.
    """
    try:
        del self.d[item]
        return
    except (KeyError, TypeError):
        # KeyError: not in the dict.  TypeError: `item` is unhashable, so
        # it can only have been stored in the fallback list.
        pass
    for i, (key, val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
    raise KeyError(item)
def discard(self, item):
    """Remove the entry stored under `item`, if there is one.

    Unlike `__delitem__`, a missing `item` is silently ignored.
    """
    try:
        del self.d[item]
        return
    except (KeyError, TypeError):
        # KeyError: not in the dict.  TypeError: `item` is unhashable, so
        # it can only have been stored in the fallback list.
        pass
    for i, (key, val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
def get(self, item, default):
try:
return self.d[item]
......@@ -245,13 +273,148 @@ class _metadict:
return value
else:
return default
def clear(self):
    """Drop every entry by rebinding `d` and `l` to new empty containers."""
    self.d, self.l = {}, []
def __str__(self):
    """Return ``(<dict>, <list>)`` built from the two internal stores."""
    parts = (self.d, self.l)
    return "(%s, %s)" % parts
class MergeFeature(object):
    """
    Keeps track of variables in env that cannot be merged together.

    That way, the MergeOptimizer can remember the result of the last merge
    pass on the env.  Candidate merges are only queued in `scheduled` here;
    this class never performs a replacement itself.
    """
    def on_attach(self, env):
        """Install this feature on `env` and process every node of it."""
        assert not hasattr(env, 'merge_feature')
        env.merge_feature = self

        ## For constants
        self.seen_constants = set()
        # variable -> signature (for constants)
        self.const_sig = _metadict()
        # signature -> variable (for constants)
        self.const_sig_inv = _metadict()

        ## For all variables
        # Set of distinct (not mergeable) nodes
        self.nodes_seen = set()

        # Each element of scheduled is a list of list of (out, new_out)
        # pairs.  Each list of pairs represents the substitution needed to
        # replace all the outputs of a node with the outputs of a
        # replacement candidate.  Each node can have several candidates.
        # For instance, if "node" has 2 outputs, and there are 3 replacement
        # candidates, we will have:
        # self.scheduled = [
        #    [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
        #     [(node.out1, cand2.out1), (node.out2, cand2.out2)],
        #     [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
        self.scheduled = []

        # List of (node, candidate) pairs, where we tried to replace node by
        # candidate, but it failed.  This is used to avoid infinite loops
        # during the replacement phase.
        self.blacklist = []

        for existing_node in env.toposort():
            self.on_import(env, existing_node)

    def on_change_input(self, env, node, i, r, new_r):
        """Re-examine `node` (and `new_r` if constant) after an input change."""
        # If inputs to node change, it is not guaranteed that it is distinct
        # from the other nodes in nodes_seen
        if node in self.nodes_seen:
            self.nodes_seen.discard(node)
            self.process_node(env, node)

        if isinstance(new_r, graph.Constant):
            self.process_constant(env, new_r)

    def on_import(self, env, node):
        """Process the constant inputs of `node`, then `node` itself."""
        for inp in node.inputs:
            if isinstance(inp, graph.Constant):
                self.process_constant(env, inp)

        self.process_node(env, node)

    def on_prune(self, env, node):
        """Forget `node`, and any constant input with at most one client."""
        self.nodes_seen.discard(node)
        for inp in node.inputs:
            if not isinstance(inp, graph.Constant):
                continue
            if len(inp.clients) > 1:
                continue
            # This was the last node using this constant
            sig = self.const_sig[inp]
            self.const_sig.discard(inp)
            self.const_sig_inv.discard(sig)
            self.seen_constants.discard(id(inp))

    def process_constant(self, env, c):
        """Check if a constant can be merged, and queue that replacement"""
        if id(c) in self.seen_constants:
            return
        sig = c.signature()
        known = self.const_sig_inv.get(sig, None)
        if known is None:
            # this is a new constant
            self.const_sig[c] = sig
            self.const_sig_inv[sig] = c
            self.seen_constants.add(id(c))
            return
        # multiple names will clobber each other..
        # we adopt convention to keep the last name
        if c.name:
            known.name = c.name
        self.scheduled.append([[(c, known)]])

    def process_node(self, env, node):
        """Check if a node can be merged, and queue that replacement."""
        if node in self.nodes_seen:
            return

        # These asserts ensure that the env has set the clients field
        # properly.  The clients should at least contain `node` itself!
        merge_candidates = []
        if node.inputs:
            first_clients = node.inputs[0].clients
            assert len(first_clients) > 0
            assert (node, 0) in node.inputs[0].clients
            merge_candidates = [client for (client, idx) in first_clients
                                if client in self.nodes_seen]

        replacement_candidates = []
        for candidate in merge_candidates:
            if candidate is node:
                continue
            if len(node.inputs) != len(candidate.inputs):
                continue

            inputs_match = all(
                node_in is cand_in
                for node_in, cand_in in zip(node.inputs, candidate.inputs))
            if not (inputs_match and node.op == candidate.op):
                continue
            if (node, candidate) in self.blacklist:
                # They were already tried, and there was an error
                continue

            # Schedule transfer of clients from node to candidate
            pairs = list(zip(node.outputs, candidate.outputs))

            # transfer names
            for node_output, cand_output in pairs:
                # clobber old name with new one
                # it's arbitrary... one of the names has to go
                if node_output.name:
                    cand_output.name = node_output.name

            replacement_candidates.append(pairs)

        if not replacement_candidates:
            self.nodes_seen.add(node)
        else:
            self.scheduled.append(replacement_candidates)
class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical and redundant.
......@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def __init__(self, skip_const_merge=False):
self.skip_const_merge = skip_const_merge
def add_requirements(self, env):
    """Make sure `env` has a MergeFeature attached.

    Nothing is done when `env` already has a `merge_feature` attribute.
    """
    # Added by default
    #env.extend(toolbox.ReplaceValidate())
    if hasattr(env, 'merge_feature'):
        return
    env.extend(MergeFeature())
def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # variable -> variable.signature() (for constants)
const_sig_inv = _metadict() # signature -> variable (for constants)
for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
def apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen = {}
for node_idx, node in enumerate(_list_of_nodes(env)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients
merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
else:
merge_candidates = []
merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
assert node is not candidate
#
#transfer clients from node to candidate
#
success = True
assert len(node.outputs) == len(candidate.outputs)
pairs = zip(node.outputs, candidate.outputs)
#transfer names
for node_output, cand_output in pairs:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
try:
env.replace_all_validate(pairs, reason="Merge")
except InconsistencyError, e:
success = False
if success:
#break out of the candidate loop
break
else:
#try the next candidate
pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env):
    """Perform the replacements queued in ``env.merge_feature.scheduled``.

    Each scheduled entry is a list of candidate substitutions for one
    node.  Candidates are tried in order and the first one accepted by
    `env.replace_all_validate` wins; the remaining ones are skipped.  A
    candidate rejected with InconsistencyError is added to the feature's
    blacklist.  The blacklist is emptied once the schedule is exhausted.
    """
    # Constant and non-constant are now applied in the same phase.
    # I am not sure why, but it seems to be faster this way.
    feature = env.merge_feature
    sched = feature.scheduled
    while sched:
        pairs_list = sched.pop()
        for pairs in pairs_list:
            try:
                env.replace_all_validate(pairs, 'Merge')
            except InconsistencyError:
                feature.blacklist.append(
                    (pairs[0][0].owner, pairs[0][1].owner))
            else:
                # Success is judged per candidate: stop at the first
                # replacement that goes through, even after failures.
                break

    # clear blacklist
    feature.blacklist = []
merge_optimizer = MergeOptimizer()
......@@ -417,8 +518,9 @@ def pre_constant_merge(vars):
"""
seen_var = set()
const_sig = {} # variable -> variable.signature() (for constants)
const_sig_inv = {} # signature -> variable (for constants)
# signature -> variable (for constants)
const_sig_inv = {}
def recursive_merge(var):
if var in seen_var:
return var
......@@ -434,12 +536,13 @@ def pre_constant_merge(vars):
const_sig_inv[sig] = var
return var
if var.owner:
for idx,inp in enumerate(var.owner.inputs):
for idx, inp in enumerate(var.owner.inputs):
var.owner.inputs[idx] = recursive_merge(inp)
return var
return map(recursive_merge, vars)
########################
### Local Optimizers ###
########################
......@@ -463,25 +566,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of variables> to use in place of `node`'s outputs in the greater graph.
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance
"""
raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__)
raise utils.MethodNotDefined("transform",
type(self), self.__class__.__name__)
def add_requirements(self, env):
    """
    If this local optimization wants to add some requirements to the env,
    This is the place to do it.

    This implementation adds nothing.
    """
    # Added by default
    #env.extend(toolbox.ReplaceValidate())
    pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print one line to `stream`: `level` spaces, the class name and
    ``id(self)``.  `depth` is accepted but not used here."""
    print >> stream, "%s%s id=%i" % (
        (' ' * level), self.__class__.__name__, id(self))
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
......@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks = []
self.transform = fn
self._tracks = tracks
def tracks(self):
    """Return the `_tracks` value stored on this instance."""
    tracked = self._tracks
    return tracked
def __str__(self):
    """Return the `__name__` attribute of this instance when it has one,
    otherwise a fixed placeholder string."""
    return getattr(self, '__name__',
                   '<FromFunctionLocalOptimizer instance>')
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print one line to `stream`: `level` spaces, ``str(self.transform)``
    and ``id(self)``.  `depth` is accepted but not used here."""
    print >> stream, "%s%s id=%i" % (
        ' ' * level,
        str(self.transform),
        id(self))
def local_optimizer(*tracks):
def decorator(f):
"""WRITEME"""
......@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer):
def __init__(self, *optimizers):
    """Group the given local optimizers.

    `reentrant` is True if any member is reentrant (a member without the
    attribute counts as reentrant).  `retains_inputs` is True only if
    every member has a true `retains_inputs` attribute.
    """
    self.opts = optimizers
    self.reentrant = any(getattr(opt, 'reentrant', True)
                         for opt in optimizers)
    self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
                              for opt in optimizers)
def __str__(self):
    """Return the `__name__` attribute of this instance when it has one,
    otherwise a placeholder followed by the list of the members' strings."""
    return getattr(self, '__name__',
                   ('<theano.gof.opt.LocalOptGroup instance>'
                    + str([str(o) for o in self.opts])))
def transform(self, node):
for opt in self.opts:
......@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer):
return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print a line for this group, then the summary of each member.

    Members are printed two levels deeper.  `depth` is decremented at each
    nesting level and the recursion stops when it reaches 0.
    """
    print >> stream, "%s%s id=%i" % (
        (' ' * level), self.__class__.__name__, id(self))
    if depth != 0:
        depth -= 1
        for lopt in self.opts:
            lopt.print_summary(stream, level=(level + 2), depth=depth)
class _LocalOpKeyOptGroup(LocalOptGroup):
......@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
reentrant = False # an OpSub does not apply to the nodes it produces
retains_inputs = True # all the inputs of the original node are transferred to the outputs
# an OpSub does not apply to the nodes it produces
reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
def __init__(self, op1, op2, transfer_tags = True):
def __init__(self, op1, op2, transfer_tags=True):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
......@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer):
return "%s(x) -> x" % (self.op)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print one line to `stream`: `level` spaces, the class name,
    ``str(self.op)`` in parentheses and ``id(self)``.  `depth` is accepted
    but not used here."""
    print >> stream, "%s%s(%s) id=%i" % (
        ' ' * level,
        self.__class__.__name__,
        str(self.op),
        id(self))
......@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False,
skip_identities_fn = None, name = None, pdb = False):
def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
......@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
......@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer):
elif isinstance(in_pattern, dict):
self.op = self.in_pattern['pattern'][0]
else:
raise TypeError("The pattern to search for must start with a specific Op instance.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
raise TypeError("The pattern to search for must start with "
"a specific Op instance.")
self.__doc__ = (self.__class__.__doc__
+ "\n\nThis instance does: "
+ str(self) + "\n")
self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn
if name:
......@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
def match(pattern, expr, u, allow_multiple_clients = False, pdb = False):
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
def retry_with_equiv():
expr_equiv = self.skip_identities(expr)
if expr_equiv is None:
......@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
......@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer):
try:
real_pattern = pattern['pattern']
except KeyError:
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
raise KeyError(
"Malformed pattern: %s (expected key 'pattern')"
% pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr):
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', allow_multiple_clients))
return match(real_pattern, expr, u,
pattern.get('allow_multiple_clients',
allow_multiple_clients))
else:
return retry_with_equiv()
elif isinstance(pattern, basestring):
......@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, graph.Constant):
if numpy.all(theano.tensor.constant(pattern).value==expr.value):
elif (isinstance(pattern, (int, float))
and isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
return u
else:
return retry_with_equiv()
if pdb:
import pdb;pdb.set_trace()
import pdb
pdb.set_trace()
return u
def build(pattern, u):
......@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer):
return pattern[0](*args)
elif isinstance(pattern, basestring):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int,float)):
elif isinstance(pattern, (int, float)):
return pattern
else:
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb)
u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb)
if u:
p = self.out_pattern
new = build(p, u)
......@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer):
return False
def __str__(self):
    """Return a readable ``<in_pattern> -> <out_pattern>`` description.

    If the instance carries a non-empty ``__name__`` attribute, that
    name is returned instead of the rendered patterns.
    """
    if getattr(self, '__name__', None):
        return self.__name__

    def pattern_to_str(pattern):
        # list/tuple: the first element is the head, the remaining
        # elements are sub-patterns rendered recursively as arguments.
        if isinstance(pattern, (list, tuple)):
            return "%s(%s)" % (
                str(pattern[0]),
                ", ".join([pattern_to_str(p) for p in pattern[1:]]))
        # dict: rendered as its 'pattern' entry followed by its
        # 'constraint' entry, or 'no conditions' when there is none.
        elif isinstance(pattern, dict):
            return "%s subject to %s" % (
                pattern_to_str(pattern['pattern']),
                str(pattern.get('constraint', 'no conditions')))
        else:
            return str(pattern)

    return "%s -> %s" % (
        pattern_to_str(self.in_pattern),
        pattern_to_str(self.out_pattern))
def __repr__(self):
    """Return the same text as ``str(self)``."""
    return str(self)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None))
print >> stream, "%s%s %s(%s, %s) id=%i" %(' '*level,
print >> stream, "%s%s %s(%s, %s) id=%i" % (
' ' * level,
self.__class__.__name__,
name,
str(self.in_pattern),
......@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer):
_logger.error(traceback.format_exc())
if isinstance(exc, AssertionError) or config.on_opt_error == 'raise':
raise exc
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
    """failure_callback for NavigatorOptimizer.

    Ignore InconsistencyErrors; every other exception is handed to
    ``NavigatorOptimizer.warn`` with the same arguments, and that
    call's result is returned.
    """
    if isinstance(exc, InconsistencyError):
        return
    return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt)
@staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt):
    """failure_callback for NavigatorOptimizer: ignore all errors.

    All four arguments are accepted and discarded; nothing is logged,
    raised or returned.
    """
    pass
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None):
"""
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self.local_opt = local_opt
if ignore_newtrees == 'auto':
......@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
def attach_updater(self, env, importer, pruner, chin=None):
"""
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if self.ignore_newtrees:
importer = None
......@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer):
if u is not None:
env.remove_feature(u)
def process_node(self, env, node, lopt = None):
def process_node(self, env, node, lopt=None):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Variables that are intended to replace `node.outputs`.
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this
function returns True.
If the env accepts the replacement, then the optimization is
successful, and this function returns True.
If there are no replacement candidates or the env rejects the replacements, this
function returns False.
If there are no replacement candidates or the env rejects the
replacements, this function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute
node's outputs.
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
......@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer):
replacements = lopt.transform(node)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs], lopt)
self.failure_callback(e, self,
[(x, None) for x in node.outputs], lopt)
return False
else:
raise
if replacements is False or replacements is None:
return False
if not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. Expected list or tuple.' % lopt)
raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' % lopt)
raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt)
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
......@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer):
except Exception, e:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback will print a
# traceback as a warning.
# This is not supposed to happen. The default failure_callback
# will print a traceback as a warning.
if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs, lopt)
return False
......@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer):
self.local_opt.add_requirements(env)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print a one-line summary of this optimizer, then of its local_opt.

    :param stream: file-like object the summary is printed to.
    :param level: number of spaces used to indent this entry;
        ``self.local_opt`` is printed two columns deeper.
    :param depth: 0 prints only this entry; otherwise ``self.local_opt``
        is summarised with ``depth - 1``.
    """
    print >> stream, "%s%s (%i)" % (
        (' ' * level), self.__class__.__name__, id(self))
    if depth != 0:
        self.local_opt.print_summary(stream, level=(level + 2),
                                     depth=(depth - 1))
class TopoOptimizer(NavigatorOptimizer):
"""WRITEME"""
def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
             failure_callback=None):
    """Store the traversal order and initialise the NavigatorOptimizer base.

    :param local_opt: forwarded unchanged to NavigatorOptimizer.__init__.
    :param order: either 'out_to_in' or 'in_to_out'; stored as
        ``self.order``.
    :param ignore_newtrees: forwarded to NavigatorOptimizer.__init__.
    :param failure_callback: forwarded to NavigatorOptimizer.__init__.
    :raises ValueError: if `order` is not one of the two accepted strings.
    """
    # Validate before touching any state so a bad call leaves no
    # half-initialised instance behind.
    if order not in ['out_to_in', 'in_to_out']:
        raise ValueError("order must be 'out_to_in' or 'in_to_out'")
    self.order = order
    NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                failure_callback)
def apply(self, env, start_from = None):
if start_from is None: start_from = env.outputs
def apply(self, env, start_from=None):
if start_from is None:
start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node):
if node is not current_node:
q.append(node)
def pruner(node):
if node is not current_node:
try:
......@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer):
self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
def __init__(self, local_opt, ignore_newtrees=False,
             failure_callback=None):
    """Check that `local_opt` exposes ``op_key`` and initialise the base.

    :param local_opt: optimizer object that must have an ``op_key``
        attribute; forwarded to NavigatorOptimizer.__init__.
    :param ignore_newtrees: forwarded to NavigatorOptimizer.__init__.
    :param failure_callback: forwarded to NavigatorOptimizer.__init__.
    :raises TypeError: if `local_opt` has no ``op_key`` attribute.
    """
    if not hasattr(local_opt, 'op_key'):
        raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
                        "an 'op_key' method.")
    NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                failure_callback)
def apply(self, env):
op = self.local_opt.op_key()
......@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q = reduce(list.__iadd__, map(env.get_nodes, op))
else:
q = list(env.get_nodes(op))
def importer(node):
if node is not current_node:
if node.op == op: q.append(node)
if node.op == op:
q.append(node)
def pruner(node):
if node is not current_node and node.op == op:
try:
......@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder())
class ChangeTracker:
def __init__(self):
self.changed = False
......@@ -1082,17 +1260,19 @@ class ChangeTracker:
def on_attach(self, env):
    """Register this tracker on `env` as ``env.change_tracker``."""
    env.change_tracker = self
class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
failure_callback=None,
max_depth=None,
max_use_ratio=None):
"""
:param optimizers: list or set of local or global optimizations to apply until
equilibrium.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
......@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = True,
failure_callback = failure_callback)
ignore_newtrees=True,
failure_callback=failure_callback)
self.local_optimizers = []
self.global_optimizers = []
......@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers.append(opt)
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, 'max_use_ratio has to be a number'
assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number')
def add_requirements(self, env):
    """Install on `env` everything this optimizer and its members need.

    Runs the parent class's ``add_requirements``, extends `env` with a
    new ChangeTracker, then lets every optimizer in
    ``self.local_optimizers`` and ``self.global_optimizers`` add its
    own requirements.
    """
    super(EquilibriumOptimizer, self).add_requirements(env)
    env.extend(ChangeTracker())
    for opt in self.local_optimizers:
        opt.add_requirements(env)
    for opt in self.global_optimizers:
        opt.add_requirements(env)
def apply(self, env, start_from = None):
def apply(self, env, start_from=None):
if start_from is None:
start_from = env.outputs
changed = True
......@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes.append(len(q))
max_nb_nodes = max(max_nb_nodes, len(q))
max_use = max_nb_nodes * self.max_use_ratio
def importer(node):
if node is not current_node:
q.append(node)
def pruner(node):
if node is not current_node:
try:
......@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
if node not in env.nodes:
break # go to next node
# go to next node
break
finally:
self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
loop_timing.append(float(time.time() - t0))
if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of "
......@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print a one-line summary of this optimizer, then of its local ones.

    :param stream: file-like object the summary is printed to.
    :param level: number of spaces used to indent this entry; each
        optimizer in ``self.local_optimizers`` is printed two columns
        deeper.
    :param depth: 0 prints only this entry; otherwise every local
        optimizer is summarised with ``depth - 1``.

    Only ``self.local_optimizers`` are listed here;
    ``self.global_optimizers`` are not printed by this method.
    """
    # 'name' is optional on the instance; None is printed when absent.
    name = getattr(self, 'name', None)
    print >> stream, "%s%s %s id=%i" % (
        (' ' * level), self.__class__.__name__, name, id(self))
    if depth != 0:
        for lopt in self.local_optimizers:
            lopt.print_summary(stream, level=(level + 2),
                               depth=(depth - 1))
#################
......@@ -1242,7 +1432,8 @@ def _check_chain(r, chain):
return False
else:
try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem):
if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
return False
except TypeError:
return False
......@@ -1256,6 +1447,7 @@ def _check_chain(r, chain):
return (r is not None)
#_check_chain.n_calls = 0
def check_chain(r, *chain):
"""WRITEME"""
if isinstance(r, graph.Apply):
......@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
def local_recursive_function( list_opt, out, optimized_vars, depth):
def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None):
return [out], optimized_vars
node = out.owner
......@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else:
if inp.owner:
outs, optimized_vars = local_recursive_function(
list_opt
, inp
, optimized_vars
, depth+1)
for k,v in zip(inp.owner.outputs, outs):
list_opt,
inp,
optimized_vars,
depth + 1)
for k, v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)]
......@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret = opt.transform(node)
if ret is not False and ret is not None:
assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret):
for k, v in zip(node.outputs, ret):
optimized_vars[k] = v
results = ret
if ret[0].owner :
if ret[0].owner:
node = out.owner
else:
break
......@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return final_outs[0]
############
### Misc ###
############
......
......@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论