Commit d4dfbf2a authored by nouiz

Merge pull request #612 from lamblin/merge_feature_rebased

Merge feature (rebased)
"""WRITEME """WRITEME
""" """
import os, logging, warnings import logging
import numpy, theano import numpy
import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam, EnumStr from theano.configparser import config, AddConfigVar, StrParam
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
"When using the default mode, we will remove optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
"When using the default mode, we will add optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
"When using the default mode, we will require optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
def check_equal(x, y): def check_equal(x, y):
""" """
Returns True iff x[0] and y[0] are equal (checks the dtype and Returns True iff x[0] and y[0] are equal (checks the dtype and
...@@ -32,35 +38,37 @@ def check_equal(x, y): ...@@ -32,35 +38,37 @@ def check_equal(x, y):
import scipy.sparse as sp import scipy.sparse as sp
x, y = x[0], y[0] x, y = x[0], y[0]
# TODO: bug in current scipy, two sparse matrices are never equal, remove when moving to 0.7 # TODO: bug in current scipy, two sparse matrices are never equal,
# remove when moving to 0.7
if sp.issparse(x): if sp.issparse(x):
x = x.todense() x = x.todense()
if sp.issparse(y): if sp.issparse(y):
y = y.todense() y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if (x.dtype != y.dtype
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) or x.shape != y.shape
or numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
else: else:
if x != y: if x != y:
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for # If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this # Mode, it will be used as the key to retrieve the real linker in this
# dictionary # dictionary
predefined_linkers = { predefined_linkers = {
'py' : gof.PerformLinker(), 'py': gof.PerformLinker(),
'c' : gof.CLinker(), 'c': gof.CLinker(),
'c|py' : gof.OpWiseCLinker(allow_gc=True), 'c|py': gof.OpWiseCLinker(allow_gc=True),
'c|py_nogc' : gof.OpWiseCLinker(allow_gc=False), 'c|py_nogc': gof.OpWiseCLinker(allow_gc=False),
'c&py' : gof.DualLinker(checker = check_equal), 'c&py': gof.DualLinker(checker=check_equal),
'vm' : gof.vm.VM_Linker(allow_gc=True, use_cloop=False), 'vm': gof.vm.VM_Linker(allow_gc=True, use_cloop=False),
'cvm' : gof.vm.VM_Linker(allow_gc=True, use_cloop=True), 'cvm': gof.vm.VM_Linker(allow_gc=True, use_cloop=True),
'vm_nogc' : gof.vm.VM_Linker(allow_gc=False, use_cloop=False), 'vm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
'cvm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=True), 'cvm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
} }
...@@ -72,37 +80,37 @@ def register_linker(name, linker): ...@@ -72,37 +80,37 @@ def register_linker(name, linker):
predefined_linkers[name] = linker predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor # If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer # for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary # in this dictionary
OPT_FAST_RUN = gof.Query(include = ['fast_run']) OPT_FAST_RUN = gof.Query(include=['fast_run'])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable') OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable')
OPT_FAST_COMPILE = gof.Query(include = ['fast_compile']) OPT_FAST_COMPILE = gof.Query(include=['fast_compile'])
OPT_STABILIZE = gof.Query(include = ['fast_run']) OPT_STABILIZE = gof.Query(include=['fast_run'])
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
predefined_optimizers = { predefined_optimizers = {
None : lambda env: None, None: (lambda env: None),
'None' : lambda env: None, 'None': (lambda env: None),
'merge' : gof.MergeOptimizer(), 'merge': gof.MergeOptimizer(),
'fast_run' : OPT_FAST_RUN, 'fast_run': OPT_FAST_RUN,
'fast_run_stable' : OPT_FAST_RUN_STABLE, 'fast_run_stable': OPT_FAST_RUN_STABLE,
'fast_compile' : OPT_FAST_COMPILE, 'fast_compile': OPT_FAST_COMPILE,
'stabilize': OPT_STABILIZE 'stabilize': OPT_STABILIZE
} }
def register_optimizer(name, opt):
    """Add a `Optimizer` which can be referred to by `name` in `Mode`.

    Raises ValueError when `name` is already registered.
    """
    if name not in predefined_optimizers:
        predefined_optimizers[name] = opt
    else:
        raise ValueError('Optimizer name already taken: %s' % name)
def register_OutputGuard_c_code(type):
    """Declare that OutputGuard's C implementation supports `type`."""
    OutputGuard.c_code_types.append(type)
class OutputGuard(gof.Op): class OutputGuard(gof.Op):
""" """
This op is used only internally by Theano. This op is used only internally by Theano.
...@@ -120,20 +128,24 @@ class OutputGuard(gof.Op): ...@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
TODO: find a current full explanation. TODO: find a current full explanation.
""" """
destroy_map = {0:[0]} destroy_map = {0: [0]}
view_map = {0:[0]} view_map = {0: [0]}
c_code_types = [] c_code_types = []
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
z, = out z, = out
z[0] = x z[0] = x
def __str__(self): def __str__(self):
return '%s' % self.__class__.__name__ return '%s' % self.__class__.__name__
...@@ -141,7 +153,8 @@ class OutputGuard(gof.Op): ...@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
x, = inp x, = inp
z, = out z, = out
if isinstance(node.inputs[0].type, theano.scalar.Scalar): if isinstance(node.inputs[0].type, theano.scalar.Scalar):
# Scalars are C objects on the stacks, and should not be inc/decrefed # Scalars are C objects on the stack,
# and should not be inc/decrefed
return """ return """
%(z)s = %(x)s; %(z)s = %(x)s;
""" % locals() """ % locals()
...@@ -161,71 +174,99 @@ class OutputGuard(gof.Op): ...@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
_output_guard = OutputGuard() _output_guard = OutputGuard()
class AddDestroyHandler(gof.Optimizer): class AddDestroyHandler(gof.Optimizer):
"""This optimizer performs two important functions: """This optimizer performs two important functions:
1) it has a 'requirement' of the destroyhandler. This means that the env will include it 1) it has a 'requirement' of the destroyhandler. This means that the env
as a feature for this optimization, and keep this feature enabled for subsequent will include it as a feature for this optimization, and keep this feature
optimizations. All optimizations that work inplace on any of their inputs must run *after* enabled for subsequent optimizations. All optimizations that work inplace
this optimization to ensure that the DestroyHandler has been included in the env. on any of their inputs must run *after* this optimization to ensure that
the DestroyHandler has been included in the env.
2) It tries to replace each output with an Op that purports to destroy it (but it won't I 2) It tries to replace each output with an Op that purports to destroy it
promise). If this replacement succeeds it means that there is a bug in theano. It should (but it won't I promise). If this replacement succeeds it means that
not be possible to destroy outputs. there is a bug in theano. It should not be possible to destroy outputs.
""" """
def apply(self, env): def apply(self, env):
for o in env.outputs: for o in env.outputs:
try: try:
env.replace_validate(o, _output_guard(o), reason='output_guard') env.replace_validate(o, _output_guard(o),
_logger.info("Output variable %s required output_guard," reason='output_guard')
" how was this output left unprotected against destructive operations?" _logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError:
#this output is already impossible to destroy. no guard necessary # This output is already impossible to destroy.
# No guard necessary
pass pass
def add_requirements(self, env): def add_requirements(self, env):
super(AddDestroyHandler, self).add_requirements(env) super(AddDestroyHandler, self).add_requirements(env)
env.extend(gof.DestroyHandler()) env.extend(gof.DestroyHandler())
class PrintCurrentEnv(gof.Optimizer): class PrintCurrentEnv(gof.Optimizer):
"""This optimizer is for debugging. """This optimizer is for debugging.
Toss it into the optimization pipeline to see the state of things at any given point. Toss it into the optimization pipeline to see the state of things at any
given point.
""" """
def __init__(self, header): def __init__(self, header):
self.header =header self.header = header
def apply(self, env): def apply(self, env):
import theano.printing import theano.printing
print "PrintCurrentEnv:", self.header print "PrintCurrentEnv:", self.header
theano.printing.debugprint(env.outputs) theano.printing.debugprint(env.outputs)
optdb = gof.SequenceDB() optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile') 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(skip_const_merge=False),
optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile') 1.2, 'fast_run', 'fast_compile')
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
1.21,)# 'fast_run', 'fast_compile') 1.21,) # 'fast_run', 'fast_compile')
optdb.register('stabilize', gof.EquilibriumDB(), # replace unstable subgraphs # replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(),
1.5, 'fast_run') 1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentEnv('Post-stabilize'), optdb.register('Print1.51', PrintCurrentEnv('Post-stabilize'),
1.51,) #'fast_run', 'fast_compile') 1.51,) # 'fast_run', 'fast_compile')
optdb.register('specialize', gof.EquilibriumDB(), # misc special cases for speed
# misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(),
2, 'fast_run') 2, 'fast_run')
optdb.register('Print2.01', PrintCurrentEnv('Post-specialize'), optdb.register('Print2.01', PrintCurrentEnv('Post-specialize'),
2.01, )#'fast_run', 'fast_compile') 2.01,) # 'fast_run', 'fast_compile')
optdb.register('uncanonicalize', gof.EquilibriumDB(),# misc special cases for speed that break canonicalization
# misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(),
3, 'fast_run') 3, 'fast_run')
optdb.register('specialize_device', gof.EquilibriumDB(), # misc special cases for speed that are dependent on the device.
48.6, 'fast_run')#must be after gpu stuff at 48.5 # misc special cases for speed that are dependent on the device.
optdb.register('merge2', gof.MergeOptimizer(), # especially constant merge optdb.register('specialize_device', gof.EquilibriumDB(),
48.6, 'fast_run') # must be after gpu stuff at 48.5
# especially constant merge
optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run') 49, 'fast_run')
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
optdb.register('merge3', gof.MergeOptimizer(), # final pass just to make sure
# final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run') 100, 'fast_run')
...@@ -251,12 +292,15 @@ class Mode(object): ...@@ -251,12 +292,15 @@ class Mode(object):
if optimizer is None: if optimizer is None:
optimizer = config.optimizer optimizer = config.optimizer
self.__setstate__((linker, optimizer)) self.__setstate__((linker, optimizer))
#self.provided_optimizer - typically the `optimizer` arg. But if the `optimizer` arg is
# keyword corresponding to a predefined Query, then this stores the query
#self._optimizer - typically same as provided_optimizer??
#self.__get_optimizer - returns self._optimizer (possibly querying optdb with self._optimizer) # self.provided_optimizer - typically the `optimizer` arg.
#self.optimizer - property that returns __get_optimizer() # But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# optdb with self._optimizer)
# self.optimizer - property that returns __get_optimizer()
def __getstate__(self): def __getstate__(self):
return (self.provided_linker, self.provided_optimizer) return (self.provided_linker, self.provided_optimizer)
...@@ -275,12 +319,13 @@ class Mode(object): ...@@ -275,12 +319,13 @@ class Mode(object):
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
self.fn_time = 0 self.fn_time = 0
linker.mode = self #TODO: WHY IS THIS HERE? linker.mode = self # TODO: WHY IS THIS HERE?
self.optimizer_time = 0 self.optimizer_time = 0
self.linker_time = 0 self.linker_time = 0
def __str__(self): def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer) return "Mode(linker = %s, optimizer = %s)" % (
self.provided_linker, self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -298,17 +343,20 @@ class Mode(object): ...@@ -298,17 +343,20 @@ class Mode(object):
return (linker, optimizer) return (linker, optimizer)
def including(self, *tags): def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
#N.B. opt might be a Query instance, not sure what else it might be... #N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags)) return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags): def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.excluding(*tags)) return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags): def requiring(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.requiring(*tags)) return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
...@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE, ...@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN': FAST_RUN, 'FAST_RUN': FAST_RUN,
} }
instanciated_default_mode=None instanciated_default_mode = None
def get_mode(orig_string): def get_mode(orig_string):
if orig_string is None: if orig_string is None:
string = config.mode string = config.mode
else: else:
string = orig_string string = orig_string
if not isinstance(string, basestring): if not isinstance(string, basestring):
return string #it is hopefully already a mode... return string # it is hopefully already a mode...
global instanciated_default_mode global instanciated_default_mode
# The default mode is cached. However, config.mode can change # The default mode is cached. However, config.mode can change
# If instanciated_default_mode has the right class, use it. # If instanciated_default_mode has the right class, use it.
if orig_string is None and instanciated_default_mode: if orig_string is None and instanciated_default_mode:
if predefined_modes.has_key(string): if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__ default_mode_class = predefined_modes[string].__class__.__name__
else: else:
default_mode_class = string default_mode_class = string
...@@ -342,7 +392,7 @@ def get_mode(orig_string): ...@@ -342,7 +392,7 @@ def get_mode(orig_string):
default_mode_class): default_mode_class):
return instanciated_default_mode return instanciated_default_mode
if string in ['Mode','ProfileMode','DebugMode']: if string in ['Mode', 'ProfileMode', 'DebugMode']:
if string == 'DebugMode': if string == 'DebugMode':
#need to import later to break circular dependency. #need to import later to break circular dependency.
from debugmode import DebugMode from debugmode import DebugMode
...@@ -350,12 +400,13 @@ def get_mode(orig_string): ...@@ -350,12 +400,13 @@ def get_mode(orig_string):
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # The import is needed in case string is ProfileMode
from profilemode import ProfileMode,prof_mode_instance_to_print from profilemode import ProfileMode, prof_mode_instance_to_print
ret = eval(string+'(linker=config.linker, optimizer=config.optimizer)') ret = eval(string
elif predefined_modes.has_key(string): + '(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
raise Exception("No predefined mode exist for string: %s"%string) raise Exception("No predefined mode exist for string: %s" % string)
if orig_string is None: if orig_string is None:
# Build and cache the default mode # Build and cache the default mode
...@@ -374,12 +425,14 @@ def get_mode(orig_string): ...@@ -374,12 +425,14 @@ def get_mode(orig_string):
return ret return ret
def get_default_mode(): def get_default_mode():
return get_mode(None) return get_mode(None)
# Removed: use config.mode instead. # Removed: use config.mode instead.
#default_mode = config.mode #default_mode = config.mode
def register_mode(name, mode): def register_mode(name, mode):
"""Add a `Mode` which can be referred to by `name` in `function`.""" """Add a `Mode` which can be referred to by `name` in `function`."""
if name in predefined_modes: if name in predefined_modes:
......
...@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain ...@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools. amount of useful generic optimization tools.
""" """
import copy
import copy, logging, sys, time import logging
import sys
import time
import numpy import numpy
import graph import graph
from env import InconsistencyError from env import InconsistencyError
import op
import utils import utils
import unify import unify
import toolbox import toolbox
import op
import theano import theano
from theano import config from theano import config
from theano.gof.python25 import any, all, deque from theano.gof.python25 import any, all, deque
from theano.configparser import AddConfigVar, BoolParam, config from theano.configparser import AddConfigVar, BoolParam
#if sys.version_info[:2] >= (2,5): #if sys.version_info[:2] >= (2,5):
# from collections import defaultdict # from collections import defaultdict
...@@ -39,9 +41,11 @@ import traceback ...@@ -39,9 +41,11 @@ import traceback
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(env): def _list_of_nodes(env):
return list(graph.io_toposort(env.inputs, env.outputs)) return list(graph.io_toposort(env.inputs, env.outputs))
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
...@@ -91,26 +95,30 @@ class Optimizer(object): ...@@ -91,26 +95,30 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, print >> stream, "%s%s %s id=%i" % (
name, id(self)) (' ' * level), self.__class__.__name__, name, id(self))
class FromFunctionOptimizer(Optimizer):
    """WRITEME

    An Optimizer wrapping a plain function `fn(env)`; normally built via
    the `optimizer` decorator.
    """
    def __init__(self, fn):
        self.apply = fn
        # Bug fix: __call__ forwards to self.fn, but previously only
        # self.apply was set, so calling an instance raised AttributeError.
        self.fn = fn

    def add_requirements(self, env):
        # Added by default
        #env.extend(toolbox.ReplaceValidate())
        pass

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        print >> stream, "%s%s id=%i" % (
            ' ' * level,
            str(self.apply),
            id(self))

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)
def optimizer(f): def optimizer(f):
"""decorator for FromFunctionOptimizer""" """decorator for FromFunctionOptimizer"""
rval = FromFunctionOptimizer(f) rval = FromFunctionOptimizer(f)
...@@ -118,7 +126,6 @@ def optimizer(f): ...@@ -118,7 +126,6 @@ def optimizer(f):
return rval return rval
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__ #inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """WRITEME
...@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def warn(exc, self, optimizer): def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer """Default failure_callback for SeqOptimizer
""" """
_logger.error("SeqOptimizer apply %s"% str(optimizer)) _logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:") _logger.error("Traceback:")
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
if config.on_opt_error == 'raise': if config.on_opt_error == 'raise':
...@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list): ...@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME """WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l=[] l = []
nb_node_before = len(env.nodes) nb_node_before = len(env.nodes)
for optimizer in self: for optimizer in self:
try: try:
t0=time.time() t0 = time.time()
optimizer.optimize(env) optimizer.optimize(env)
l.append(float(time.time()-t0)) l.append(float(time.time() - t0))
except AssertionError: # do not catch Assertion failures except AssertionError:
# do not catch Assertion failures
raise raise
except Exception, e: except Exception, e:
if self.failure_callback: if self.failure_callback:
...@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list): ...@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation #added to override the list's __neq__ implementation
return id(self) != id(other) return id(self) != id(other)
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, name, id(self)) print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self))
# This way, -1 will do all depth # This way, -1 will do all depth
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for opt in self: for opt in self:
opt.print_summary(stream, level=level+2, depth=depth) opt.print_summary(stream, level=(level + 2), depth=depth)
class _metadict: class _metadict:
...@@ -219,17 +225,39 @@ class _metadict: ...@@ -219,17 +225,39 @@ class _metadict:
def __init__(self): def __init__(self):
self.d = {} self.d = {}
self.l = [] self.l = []
def __getitem__(self, item): def __getitem__(self, item):
return self.get(item, None) return self.get(item, None)
def __setitem__(self, item, value): def __setitem__(self, item, value):
try: try:
self.d[item] = value self.d[item] = value
except Exception: except Exception:
for i, (key,val) in enumerate(self.l): for i, (key, val) in enumerate(self.l):
if key == item: if key == item:
self.l[i] = (item, value) self.l[i] = (item, value)
return return
self.l.append((item, value)) self.l.append((item, value))
def __delitem__(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
raise KeyError(item)
def discard(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
def get(self, item, default): def get(self, item, default):
try: try:
return self.d[item] return self.d[item]
...@@ -245,13 +273,148 @@ class _metadict: ...@@ -245,13 +273,148 @@ class _metadict:
return value return value
else: else:
return default return default
def clear(self): def clear(self):
self.d = {} self.d = {}
self.l = [] self.l = []
def __str__(self): def __str__(self):
return "(%s, %s)" % (self.d, self.l) return "(%s, %s)" % (self.d, self.l)
class MergeFeature(object):
    """
    Keeps track of variables in env that cannot be merged together.

    That way, the MergeOptimizer can remember the result of the last merge
    pass on the env.
    """
    def on_attach(self, env):
        # Register this feature on the env; each env may have at most one.
        assert not hasattr(env, 'merge_feature')
        env.merge_feature = self

        ## For constants
        # id() of constants already processed (constants themselves may be
        # unhashable, so we track identities)
        self.seen_constants = set()
        # variable -> signature (for constants)
        self.const_sig = _metadict()
        # signature -> variable (for constants)
        self.const_sig_inv = _metadict()

        ## For all variables
        # Set of distinct (not mergeable) nodes
        self.nodes_seen = set()

        # Each element of scheduled is a list of list of (out, new_out) pairs.
        # Each list of pairs represent the substitution needed to replace all
        # the outputs of a node with the outputs of a replacement candidate.
        # Each node can have several candidates. For instance, if "node" has
        # 2 outputs, and there are 3 replacement candidates, we will have:
        # self.scheduled = [
        #     [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
        #      [(node.out1, cand2.out1), (node.out2, cand2.out2)],
        #      [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
        self.scheduled = []

        # List of (node, candidate) pairs, where we tried to replace node by
        # candidate, but it failed. This is used to avoid infinite loops
        # during the replacement phase.
        self.blacklist = []

        # Process the graph that existed before this feature was attached.
        for node in env.toposort():
            self.on_import(env, node)

    def on_change_input(self, env, node, i, r, new_r):
        # If inputs to node change, it is not guaranteed that it is distinct
        # from the other nodes in nodes_seen
        if node in self.nodes_seen:
            self.nodes_seen.discard(node)
            self.process_node(env, node)
        if isinstance(new_r, graph.Constant):
            self.process_constant(env, new_r)

    def on_import(self, env, node):
        # Called when `node` enters the env: consider its constant inputs
        # for constant-merging, then the node itself for node-merging.
        for c in node.inputs:
            if isinstance(c, graph.Constant):
                self.process_constant(env, c)
        self.process_node(env, node)

    def on_prune(self, env, node):
        # Called when `node` leaves the env: drop our bookkeeping for it.
        self.nodes_seen.discard(node)
        for c in node.inputs:
            if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
                # This was the last node using this constant
                sig = self.const_sig[c]
                self.const_sig.discard(c)
                self.const_sig_inv.discard(sig)
                self.seen_constants.discard(id(c))

    def process_constant(self, env, c):
        """Check if a constant can be merged, and queue that replacement"""
        if id(c) in self.seen_constants:
            return
        sig = c.signature()
        other_c = self.const_sig_inv.get(sig, None)
        if other_c is not None:
            # multiple names will clobber each other..
            # we adopt convention to keep the last name
            if c.name:
                other_c.name = c.name
            # A constant substitution is a single (old, new) pair.
            self.scheduled.append([[(c, other_c)]])
        else:
            #this is a new constant
            self.const_sig[c] = sig
            self.const_sig_inv[sig] = c
            self.seen_constants.add(id(c))

    def process_node(self, env, node):
        """Check if a node can be merged, and queue that replacement."""
        if node in self.nodes_seen:
            return

        # These asserts ensure that the env has set the clients field properly.
        # The clients should at least contain `node` itself!
        if node.inputs:
            assert len(node.inputs[0].clients) > 0
            assert (node, 0) in node.inputs[0].clients

            # Only nodes sharing our first input can possibly be duplicates;
            # restrict the search to those already known to be distinct.
            merge_candidates = [c for (c, i) in node.inputs[0].clients
                                if c in self.nodes_seen]
        else:
            merge_candidates = []

        replacement_candidates = []
        for candidate in merge_candidates:
            if candidate is node:
                continue
            if len(node.inputs) != len(candidate.inputs):
                continue

            # A merge requires identical input variables (by identity)
            # and equal ops.
            inputs_match = all(node_in is cand_in
                    for node_in, cand_in in zip(node.inputs, candidate.inputs))
            if inputs_match and node.op == candidate.op:
                if (node, candidate) in self.blacklist:
                    # They were already tried, and there was an error
                    continue

                # Schedule transfer of clients from node to candidate
                pairs = zip(node.outputs, candidate.outputs)

                #transfer names
                for node_output, cand_output in pairs:
                    #clobber old name with new one
                    #it's arbitrary... one of the names has to go
                    if node_output.name:
                        cand_output.name = node_output.name

                replacement_candidates.append(pairs)

        if replacement_candidates:
            self.scheduled.append(replacement_candidates)
        else:
            # No duplicate found: remember this node as distinct.
            self.nodes_seen.add(node)
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
Merges parts of the graph that are identical and redundant. Merges parts of the graph that are identical and redundant.
...@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer): ...@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1). int(1) for example, are transferred to a particular instance of int(1).
""" """
def __init__(self, skip_const_merge=False):
self.skip_const_merge = skip_const_merge
def add_requirements(self, env): def add_requirements(self, env):
# Added by default # Added by default
#env.extend(toolbox.ReplaceValidate()) #env.extend(toolbox.ReplaceValidate())
pass if not hasattr(env, 'merge_feature'):
env.extend(MergeFeature())
def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # variable -> variable.signature() (for constants)
const_sig_inv = _metadict() # signature -> variable (for constants)
for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
def apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen = {}
for node_idx, node in enumerate(_list_of_nodes(env)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients
merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
else:
merge_candidates = []
merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
assert node is not candidate
#
#transfer clients from node to candidate
#
success = True
assert len(node.outputs) == len(candidate.outputs)
pairs = zip(node.outputs, candidate.outputs)
#transfer names
for node_output, cand_output in pairs:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
try:
env.replace_all_validate(pairs, reason="Merge")
except InconsistencyError, e:
success = False
if success:
#break out of the candidate loop
break
else:
#try the next candidate
pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env): def apply(self, env):
if not self.skip_const_merge: # Constant and non-constant are now applied in the same phase.
self.apply_constant_merge(env) # I am not sure why, but it seems to be faster this way.
self.apply_node_merge(env) sched = env.merge_feature.scheduled
while sched:
pairs_list = sched.pop()
success = True
for pairs in pairs_list:
try:
env.replace_all_validate(pairs, 'Merge')
except InconsistencyError:
success = False
env.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
if success:
break
# clear blacklist
env.merge_feature.blacklist = []
merge_optimizer = MergeOptimizer() merge_optimizer = MergeOptimizer()
...@@ -417,8 +518,9 @@ def pre_constant_merge(vars): ...@@ -417,8 +518,9 @@ def pre_constant_merge(vars):
""" """
seen_var = set() seen_var = set()
const_sig = {} # variable -> variable.signature() (for constants) # signature -> variable (for constants)
const_sig_inv = {} # signature -> variable (for constants) const_sig_inv = {}
def recursive_merge(var): def recursive_merge(var):
if var in seen_var: if var in seen_var:
return var return var
...@@ -434,12 +536,13 @@ def pre_constant_merge(vars): ...@@ -434,12 +536,13 @@ def pre_constant_merge(vars):
const_sig_inv[sig] = var const_sig_inv[sig] = var
return var return var
if var.owner: if var.owner:
for idx,inp in enumerate(var.owner.inputs): for idx, inp in enumerate(var.owner.inputs):
var.owner.inputs[idx] = recursive_merge(inp) var.owner.inputs[idx] = recursive_merge(inp)
return var return var
return map(recursive_merge, vars) return map(recursive_merge, vars)
######################## ########################
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
...@@ -463,25 +566,31 @@ class LocalOptimizer(object): ...@@ -463,25 +566,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two Subclasses should implement this function so that it returns one of two
kinds of things: kinds of things:
- False to indicate that no optimization can be applied to this `node`; or - False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the greater graph. - <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance :type node: an Apply instance
""" """
raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__) raise utils.MethodNotDefined("transform",
type(self), self.__class__.__name__)
def add_requirements(self, env): def add_requirements(self, env):
"""If this local optimization wants to add some requirements to the env, """
This is the place to do it.""" If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default # Added by default
#env.extend(toolbox.ReplaceValidate()) #env.extend(toolbox.ReplaceValidate())
pass pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self))
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """WRITEME"""
...@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks = [] tracks = []
self.transform = fn self.transform = fn
self._tracks = tracks self._tracks = tracks
def tracks(self): def tracks(self):
return self._tracks return self._tracks
def __str__(self): def __str__(self):
return getattr(self, '__name__', '<FromFunctionLocalOptimizer instance>') return getattr(self, '__name__',
'<FromFunctionLocalOptimizer instance>')
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, print >> stream, "%s%s id=%i" % (
' ' * level,
str(self.transform), str(self.transform),
id(self)) id(self))
def local_optimizer(*tracks): def local_optimizer(*tracks):
def decorator(f): def decorator(f):
"""WRITEME""" """WRITEME"""
...@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer): ...@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer):
def __init__(self, *optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers) self.reentrant = any(getattr(opt, 'reentrant', True)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers)
def __str__(self): def __str__(self):
return getattr(self, '__name__', '<theano.gof.opt.LocalOptGroup instance>'+str([str(o) for o in self.opts])) return getattr(self, '__name__',
('<theano.gof.opt.LocalOptGroup instance>'
+ str([str(o) for o in self.opts])))
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
...@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer): ...@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer):
return repl return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self))
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for lopt in self.opts: for lopt in self.opts:
lopt.print_summary(stream, level=level+2, depth=depth) lopt.print_summary(stream, level=(level + 2), depth=depth)
class _LocalOpKeyOptGroup(LocalOptGroup): class _LocalOpKeyOptGroup(LocalOptGroup):
...@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer): ...@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing. another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
""" """
reentrant = False # an OpSub does not apply to the nodes it produces # an OpSub does not apply to the nodes it produces
retains_inputs = True # all the inputs of the original node are transferred to the outputs reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
def __init__(self, op1, op2, transfer_tags = True): def __init__(self, op1, op2, transfer_tags=True):
""" """
op1.make_node and op2.make_node must take the same number of op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs. inputs and have the same number of outputs.
...@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer): ...@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer):
return "%s(x) -> x" % (self.op) return "%s(x) -> x" % (self.op)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s(%s) id=%i" %(' '*level, print >> stream, "%s%s(%s) id=%i" % (
' ' * level,
self.__class__.__name__, self.__class__.__name__,
str(self.op), str(self.op),
id(self)) id(self))
...@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer): ...@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x', PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}), 'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn = None, name = None, pdb = False): skip_identities_fn=None, name=None, pdb=False):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
...@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer): ...@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail :param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than if one of the subpatterns has more than
one client. one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match. :param pdb: if True, we invoke pdb when the first node in the
pattern match.
""" """
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
...@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer): ...@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer):
elif isinstance(in_pattern, dict): elif isinstance(in_pattern, dict):
self.op = self.in_pattern['pattern'][0] self.op = self.in_pattern['pattern'][0]
else: else:
raise TypeError("The pattern to search for must start with a specific Op instance.") raise TypeError("The pattern to search for must start with "
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" "a specific Op instance.")
self.__doc__ = (self.__class__.__doc__
+ "\n\nThis instance does: "
+ str(self) + "\n")
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
...@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer): ...@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, allow_multiple_clients = False, pdb = False): def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
def retry_with_equiv(): def retry_with_equiv():
expr_equiv = self.skip_identities(expr) expr_equiv = self.skip_identities(expr)
if expr_equiv is None: if expr_equiv is None:
...@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer): ...@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1): if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
return retry_with_equiv() return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv() return retry_with_equiv()
...@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer): ...@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer):
try: try:
real_pattern = pattern['pattern'] real_pattern = pattern['pattern']
except KeyError: except KeyError:
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern) raise KeyError(
"Malformed pattern: %s (expected key 'pattern')"
% pattern)
constraint = pattern.get('constraint', lambda expr: True) constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', allow_multiple_clients)) return match(real_pattern, expr, u,
pattern.get('allow_multiple_clients',
allow_multiple_clients))
else: else:
return retry_with_equiv() return retry_with_equiv()
elif isinstance(pattern, basestring): elif isinstance(pattern, basestring):
...@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer): ...@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv() return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, graph.Constant): elif (isinstance(pattern, (int, float))
if numpy.all(theano.tensor.constant(pattern).value==expr.value): and isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr): elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
if pdb: if pdb:
import pdb;pdb.set_trace() import pdb
pdb.set_trace()
return u return u
def build(pattern, u): def build(pattern, u):
...@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer): ...@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer):
return pattern[0](*args) return pattern[0](*args)
elif isinstance(pattern, basestring): elif isinstance(pattern, basestring):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
elif isinstance(pattern, (int,float)): elif isinstance(pattern, (int, float)):
return pattern return pattern
else: else:
return pattern.clone() return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb) u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb)
if u: if u:
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
...@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer): ...@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer):
return False return False
def __str__(self): def __str__(self):
if getattr(self,'__name__',None): if getattr(self, '__name__', None):
return self.__name__ return self.__name__
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]])) return "%s(%s)" % (
str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions'))) return "%s subject to %s" % (
pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (
pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None)) name = getattr(self, '__name__', getattr(self, 'name', None))
print >> stream, "%s%s %s(%s, %s) id=%i" %(' '*level, print >> stream, "%s%s %s(%s, %s) id=%i" % (
' ' * level,
self.__class__.__name__, self.__class__.__name__,
name, name,
str(self.in_pattern), str(self.in_pattern),
...@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer): ...@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer):
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
if isinstance(exc, AssertionError) or config.on_opt_error == 'raise': if isinstance(exc, AssertionError) or config.on_opt_error == 'raise':
raise exc raise exc
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt): def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback """failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt) return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt)
@staticmethod @staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt): def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors """failure_callback for NavigatorOptimizer: ignore all errors
""" """
pass pass
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None):
""" """
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too). :param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees: :param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization - True: new subgraphs returned by an optimization is not a
- False: new subgraphs returned by an optimization is a candidate for optimization candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute. - False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback: :param failure_callback:
a function that takes (exception, navigator, [(old, new), a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception. (old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'. If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by example) then the new variables will be the ones created by
transform(). transform().
If this parameter is None, then exceptions are not caught here (raised normally). If this parameter is None, then exceptions are not caught here
(raised normally).
""" """
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == 'auto': if ignore_newtrees == 'auto':
...@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer): ...@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner, chin = None): def attach_updater(self, env, importer, pruner, chin=None):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality. """
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph. :param importer: function that will be called whenever when
:param pruner: function to be called when optimizations remove stuff from graph. optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change. :param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later! :returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
""" """
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
...@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer): ...@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer):
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
def process_node(self, env, node, lopt = None): def process_node(self, env, node, lopt=None):
""" """
This function will use `lopt` to `transform` the `node`. The `transform` method will This function will use `lopt` to `transform` the `node`. The
return either False or a list of Variables that are intended to replace `node.outputs`. `transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this If the env accepts the replacement, then the optimization is
function returns True. successful, and this function returns True.
If there are no replacement candidates or the env rejects the replacements, this If there are no replacement candidates or the env rejects the
function returns False. replacements, this function returns False.
:param env: an Env :param env: an Env
:param node: an Apply instance in `env` :param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute :param lopt: a LocalOptimizer instance that may have a better idea for
node's outputs. how to compute node's outputs.
:rtype: Bool :rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`. :returns: True iff the `node`'s outputs were replaced in the `env`.
...@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer): ...@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer):
replacements = lopt.transform(node) replacements = lopt.transform(node)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs], lopt) self.failure_callback(e, self,
[(x, None) for x in node.outputs], lopt)
return False return False
else: else:
raise raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
if not isinstance(replacements, (tuple, list)): if not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. Expected list or tuple.' % lopt) raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements): if len(node.outputs) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' % lopt) raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt)
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
# the replacement # the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements) repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
...@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer):
except Exception, e: except Exception, e:
# This means the replacements were rejected by the env. # This means the replacements were rejected by the env.
# #
# This is not supposed to happen. The default failure_callback will print a # This is not supposed to happen. The default failure_callback
# traceback as a warning. # will print a traceback as a warning.
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs, lopt) self.failure_callback(e, self, repl_pairs, lopt)
return False return False
...@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer): ...@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer):
self.local_opt.add_requirements(env) self.local_opt.add_requirements(env)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s (%i)" %(' '*level, self.__class__.__name__, id(self)) print >> stream, "%s%s (%i)" % (
(' ' * level), self.__class__.__name__, id(self))
if depth != 0: if depth != 0:
self.local_opt.print_summary(stream, level=level+2, depth=depth-1) self.local_opt.print_summary(stream, level=(level + 2),
depth=(depth - 1))
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, order = 'in_to_out', ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None):
if order not in ['out_to_in', 'in_to_out']: if order not in ['out_to_in', 'in_to_out']:
raise ValueError("order must be 'out_to_in' or 'in_to_out'") raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback)
def apply(self, env, start_from = None): def apply(self, env, start_from=None):
if start_from is None: start_from = env.outputs if start_from is None:
start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from)) q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: try:
...@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer):
self.detach_updater(env, u) self.detach_updater(env, u)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None):
if not hasattr(local_opt, 'op_key'): if not hasattr(local_opt, 'op_key'):
raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.") raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) "an 'op_key' method.")
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback)
def apply(self, env): def apply(self, env):
op = self.local_opt.op_key() op = self.local_opt.op_key()
...@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q = reduce(list.__iadd__, map(env.get_nodes, op)) q = reduce(list.__iadd__, map(env.get_nodes, op))
else: else:
q = list(env.get_nodes(op)) q = list(env.get_nodes(op))
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
if node.op == op: q.append(node) if node.op == op:
q.append(node)
def pruner(node): def pruner(node):
if node is not current_node and node.op == op: if node is not current_node and node.op == op:
try: try:
...@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
class ChangeTracker: class ChangeTracker:
def __init__(self): def __init__(self):
self.changed = False self.changed = False
...@@ -1082,17 +1260,19 @@ class ChangeTracker: ...@@ -1082,17 +1260,19 @@ class ChangeTracker:
def on_attach(self, env): def on_attach(self, env):
env.change_tracker = self env.change_tracker = self
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self, def __init__(self,
optimizers, optimizers,
failure_callback = None, failure_callback=None,
max_depth = None, max_depth=None,
max_use_ratio = None): max_use_ratio=None):
""" """
:param optimizers: list or set of local or global optimizations to apply until :param optimizers: list or set of local or global optimizations to
equilibrium. apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number) :param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5) :param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super(EquilibriumOptimizer, self).__init__( super(EquilibriumOptimizer, self).__init__(
None, None,
ignore_newtrees = True, ignore_newtrees=True,
failure_callback = failure_callback) failure_callback=failure_callback)
self.local_optimizers = [] self.local_optimizers = []
self.global_optimizers = [] self.global_optimizers = []
...@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
self.max_depth = max_depth self.max_depth = max_depth
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, 'max_use_ratio has to be a number' assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number')
def add_requirements(self, env): def add_requirements(self, env):
super(EquilibriumOptimizer, self).add_requirements(env) super(EquilibriumOptimizer, self).add_requirements(env)
env.extend(ChangeTracker()) env.extend(ChangeTracker())
for opt in self.local_optimizers:
opt.add_requirements(env)
for opt in self.global_optimizers:
opt.add_requirements(env)
def apply(self, env, start_from = None): def apply(self, env, start_from=None):
if start_from is None: if start_from is None:
start_from = env.outputs start_from = env.outputs
changed = True changed = True
...@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes.append(len(q)) nb_nodes.append(len(q))
max_nb_nodes = max(max_nb_nodes, len(q)) max_nb_nodes = max(max_nb_nodes, len(q))
max_use = max_nb_nodes * self.max_use_ratio max_use = max_nb_nodes * self.max_use_ratio
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: try:
...@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name = (getattr(lopt, "name", None) opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", "")) or getattr(lopt, "__name__", ""))
if node not in env.nodes: if node not in env.nodes:
break # go to next node # go to next node
break
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
loop_timing.append(float(time.time() - t0)) loop_timing.append(float(time.time() - t0))
if max_use_abort: if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name _logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of " + ". You can safely raise the current threshold of "
...@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" %(' '*level, self.__class__.__name__, name, id(self)) print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self))
if depth != 0: if depth != 0:
for lopt in self.local_optimizers: for lopt in self.local_optimizers:
lopt.print_summary(stream, level=level+2, depth=depth-1) lopt.print_summary(stream, level=(level + 2),
depth=(depth - 1))
################# #################
...@@ -1242,7 +1432,8 @@ def _check_chain(r, chain): ...@@ -1242,7 +1432,8 @@ def _check_chain(r, chain):
return False return False
else: else:
try: try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem): if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
return False return False
except TypeError: except TypeError:
return False return False
...@@ -1256,6 +1447,7 @@ def _check_chain(r, chain): ...@@ -1256,6 +1447,7 @@ def _check_chain(r, chain):
return (r is not None) return (r is not None)
#_check_chain.n_calls = 0 #_check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME""" """WRITEME"""
if isinstance(r, graph.Apply): if isinstance(r, graph.Apply):
...@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can add additional node to the inputs of the node, it can
be needed to call this function multiple time. be needed to call this function multiple time.
''' '''
def local_recursive_function( list_opt, out, optimized_vars, depth): def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None): if not getattr(out, 'owner', None):
return [out], optimized_vars return [out], optimized_vars
node = out.owner node = out.owner
...@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else: else:
if inp.owner: if inp.owner:
outs, optimized_vars = local_recursive_function( outs, optimized_vars = local_recursive_function(
list_opt list_opt,
, inp inp,
, optimized_vars optimized_vars,
, depth+1) depth + 1)
for k,v in zip(inp.owner.outputs, outs): for k, v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)] nw_in = outs[inp.owner.outputs.index(inp)]
...@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret = opt.transform(node) ret = opt.transform(node)
if ret is not False and ret is not None: if ret is not False and ret is not None:
assert len(ret) == len(node.outputs) assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret): for k, v in zip(node.outputs, ret):
optimized_vars[k] = v optimized_vars[k] = v
results = ret results = ret
if ret[0].owner : if ret[0].owner:
node = out.owner node = out.owner
else: else:
break break
...@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return final_outs[0] return final_outs[0]
############ ############
### Misc ### ### Misc ###
############ ############
......
...@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node): ...@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
merged_slices.append(slice1) merged_slices.append(slice1)
pos_1 += 1 pos_1 += 1
if pos_2 < len(slices2): if pos_2 < len(slices2):
merged_slices += slices2[pos_2:] merged_slices += slices2[pos_2:]
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论