提交 c066b30a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

PEP 8

上级 5a4734bd
"""WRITEME """WRITEME
""" """
import os, logging, warnings import logging
import numpy, theano import numpy
import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam, EnumStr from theano.configparser import config, AddConfigVar, StrParam
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
"When using the default mode, we will remove optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
"When using the default mode, we will add optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
"When using the default mode, we will require optimizer with that tag. Separate many tags with ':'.", ("When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
def check_equal(x, y): def check_equal(x, y):
""" """
Returns True iff x[0] and y[0] are equal (checks the dtype and Returns True iff x[0] and y[0] are equal (checks the dtype and
...@@ -32,35 +38,37 @@ def check_equal(x, y): ...@@ -32,35 +38,37 @@ def check_equal(x, y):
import scipy.sparse as sp import scipy.sparse as sp
x, y = x[0], y[0] x, y = x[0], y[0]
# TODO: bug in current scipy, two sparse matrices are never equal, remove when moving to 0.7 # TODO: bug in current scipy, two sparse matrices are never equal,
# remove when moving to 0.7
if sp.issparse(x): if sp.issparse(x):
x = x.todense() x = x.todense()
if sp.issparse(y): if sp.issparse(y):
y = y.todense() y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if (x.dtype != y.dtype
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) or x.shape != y.shape
or numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
else: else:
if x != y: if x != y:
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for # If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this # Mode, it will be used as the key to retrieve the real linker in this
# dictionary # dictionary
predefined_linkers = { predefined_linkers = {
'py' : gof.PerformLinker(), 'py': gof.PerformLinker(),
'c' : gof.CLinker(), 'c': gof.CLinker(),
'c|py' : gof.OpWiseCLinker(allow_gc=True), 'c|py': gof.OpWiseCLinker(allow_gc=True),
'c|py_nogc' : gof.OpWiseCLinker(allow_gc=False), 'c|py_nogc': gof.OpWiseCLinker(allow_gc=False),
'c&py' : gof.DualLinker(checker = check_equal), 'c&py': gof.DualLinker(checker=check_equal),
'vm' : gof.vm.VM_Linker(allow_gc=True, use_cloop=False), 'vm': gof.vm.VM_Linker(allow_gc=True, use_cloop=False),
'cvm' : gof.vm.VM_Linker(allow_gc=True, use_cloop=True), 'cvm': gof.vm.VM_Linker(allow_gc=True, use_cloop=True),
'vm_nogc' : gof.vm.VM_Linker(allow_gc=False, use_cloop=False), 'vm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
'cvm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=True), 'cvm_nogc': gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
} }
...@@ -72,37 +80,37 @@ def register_linker(name, linker): ...@@ -72,37 +80,37 @@ def register_linker(name, linker):
predefined_linkers[name] = linker predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor # If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer # for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary # in this dictionary
OPT_FAST_RUN = gof.Query(include = ['fast_run']) OPT_FAST_RUN = gof.Query(include=['fast_run'])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable') OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable')
OPT_FAST_COMPILE = gof.Query(include = ['fast_compile']) OPT_FAST_COMPILE = gof.Query(include=['fast_compile'])
OPT_STABILIZE = gof.Query(include = ['fast_run']) OPT_STABILIZE = gof.Query(include=['fast_run'])
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
predefined_optimizers = { predefined_optimizers = {
None : lambda env: None, None: (lambda env: None),
'None' : lambda env: None, 'None': (lambda env: None),
'merge' : gof.MergeOptimizer(), 'merge': gof.MergeOptimizer(),
'fast_run' : OPT_FAST_RUN, 'fast_run': OPT_FAST_RUN,
'fast_run_stable' : OPT_FAST_RUN_STABLE, 'fast_run_stable': OPT_FAST_RUN_STABLE,
'fast_compile' : OPT_FAST_COMPILE, 'fast_compile': OPT_FAST_COMPILE,
'stabilize': OPT_STABILIZE 'stabilize': OPT_STABILIZE
} }
def register_optimizer(name, opt): def register_optimizer(name, opt):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`.""" """Add a `Optimizer` which can be referred to by `name` in `Mode`."""
if name in predefined_optimizers: if name in predefined_optimizers:
raise ValueError('Optimizer name already taken: %s' % name) raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
def register_OutputGuard_c_code(type): def register_OutputGuard_c_code(type):
OutputGuard.c_code_types.append(type) OutputGuard.c_code_types.append(type)
class OutputGuard(gof.Op): class OutputGuard(gof.Op):
""" """
This op is used only internally by Theano. This op is used only internally by Theano.
...@@ -120,20 +128,24 @@ class OutputGuard(gof.Op): ...@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
TODO: find a current full explanation. TODO: find a current full explanation.
""" """
destroy_map = {0:[0]} destroy_map = {0: [0]}
view_map = {0:[0]} view_map = {0: [0]}
c_code_types = [] c_code_types = []
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
z, = out z, = out
z[0] = x z[0] = x
def __str__(self): def __str__(self):
return '%s' % self.__class__.__name__ return '%s' % self.__class__.__name__
...@@ -141,7 +153,8 @@ class OutputGuard(gof.Op): ...@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
x, = inp x, = inp
z, = out z, = out
if isinstance(node.inputs[0].type, theano.scalar.Scalar): if isinstance(node.inputs[0].type, theano.scalar.Scalar):
# Scalars are C objects on the stacks, and should not be inc/decrefed # Scalars are C objects on the stack,
# and should not be inc/decrefed
return """ return """
%(z)s = %(x)s; %(z)s = %(x)s;
""" % locals() """ % locals()
...@@ -161,71 +174,99 @@ class OutputGuard(gof.Op): ...@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
_output_guard = OutputGuard() _output_guard = OutputGuard()
class AddDestroyHandler(gof.Optimizer): class AddDestroyHandler(gof.Optimizer):
"""This optimizer performs two important functions: """This optimizer performs two important functions:
1) it has a 'requirement' of the destroyhandler. This means that the env will include it 1) it has a 'requirement' of the destroyhandler. This means that the env
as a feature for this optimization, and keep this feature enabled for subsequent will include it as a feature for this optimization, and keep this feature
optimizations. All optimizations that work inplace on any of their inputs must run *after* enabled for subsequent optimizations. All optimizations that work inplace
this optimization to ensure that the DestroyHandler has been included in the env. on any of their inputs must run *after* this optimization to ensure that
the DestroyHandler has been included in the env.
2) It tries to replace each output with an Op that purports to destroy it (but it won't I 2) It tries to replace each output with an Op that purports to destroy it
promise). If this replacement succeeds it means that there is a bug in theano. It should (but it won't I promise). If this replacement succeeds it means that
not be possible to destroy outputs. there is a bug in theano. It should not be possible to destroy outputs.
""" """
def apply(self, env): def apply(self, env):
for o in env.outputs: for o in env.outputs:
try: try:
env.replace_validate(o, _output_guard(o), reason='output_guard') env.replace_validate(o, _output_guard(o),
_logger.info("Output variable %s required output_guard," reason='output_guard')
" how was this output left unprotected against destructive operations?" _logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError:
#this output is already impossible to destroy. no guard necessary # This output is already impossible to destroy.
# No guard necessary
pass pass
def add_requirements(self, env): def add_requirements(self, env):
super(AddDestroyHandler, self).add_requirements(env) super(AddDestroyHandler, self).add_requirements(env)
env.extend(gof.DestroyHandler()) env.extend(gof.DestroyHandler())
class PrintCurrentEnv(gof.Optimizer): class PrintCurrentEnv(gof.Optimizer):
"""This optimizer is for debugging. """This optimizer is for debugging.
Toss it into the optimization pipeline to see the state of things at any given point. Toss it into the optimization pipeline to see the state of things at any
given point.
""" """
def __init__(self, header): def __init__(self, header):
self.header =header self.header = header
def apply(self, env): def apply(self, env):
import theano.printing import theano.printing
print "PrintCurrentEnv:", self.header print "PrintCurrentEnv:", self.header
theano.printing.debugprint(env.outputs) theano.printing.debugprint(env.outputs)
optdb = gof.SequenceDB() optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile') 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile') 1.2, 'fast_run', 'fast_compile')
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
1.21,)# 'fast_run', 'fast_compile') 1.21,) # 'fast_run', 'fast_compile')
optdb.register('stabilize', gof.EquilibriumDB(), # replace unstable subgraphs # replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(),
1.5, 'fast_run') 1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentEnv('Post-stabilize'), optdb.register('Print1.51', PrintCurrentEnv('Post-stabilize'),
1.51,) #'fast_run', 'fast_compile') 1.51,) # 'fast_run', 'fast_compile')
optdb.register('specialize', gof.EquilibriumDB(), # misc special cases for speed
# misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(),
2, 'fast_run') 2, 'fast_run')
optdb.register('Print2.01', PrintCurrentEnv('Post-specialize'), optdb.register('Print2.01', PrintCurrentEnv('Post-specialize'),
2.01, )#'fast_run', 'fast_compile') 2.01,) # 'fast_run', 'fast_compile')
optdb.register('uncanonicalize', gof.EquilibriumDB(),# misc special cases for speed that break canonicalization
# misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(),
3, 'fast_run') 3, 'fast_run')
optdb.register('specialize_device', gof.EquilibriumDB(), # misc special cases for speed that are dependent on the device.
48.6, 'fast_run')#must be after gpu stuff at 48.5 # misc special cases for speed that are dependent on the device.
optdb.register('merge2', gof.MergeOptimizer(), # especially constant merge optdb.register('specialize_device', gof.EquilibriumDB(),
48.6, 'fast_run') # must be after gpu stuff at 48.5
# especially constant merge
optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run') 49, 'fast_run')
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
optdb.register('merge3', gof.MergeOptimizer(), # final pass just to make sure
# final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run') 100, 'fast_run')
...@@ -251,12 +292,15 @@ class Mode(object): ...@@ -251,12 +292,15 @@ class Mode(object):
if optimizer is None: if optimizer is None:
optimizer = config.optimizer optimizer = config.optimizer
self.__setstate__((linker, optimizer)) self.__setstate__((linker, optimizer))
#self.provided_optimizer - typically the `optimizer` arg. But if the `optimizer` arg is
# keyword corresponding to a predefined Query, then this stores the query
#self._optimizer - typically same as provided_optimizer??
#self.__get_optimizer - returns self._optimizer (possibly querying optdb with self._optimizer) # self.provided_optimizer - typically the `optimizer` arg.
#self.optimizer - property that returns __get_optimizer() # But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# optdb with self._optimizer)
# self.optimizer - property that returns __get_optimizer()
def __getstate__(self): def __getstate__(self):
return (self.provided_linker, self.provided_optimizer) return (self.provided_linker, self.provided_optimizer)
...@@ -275,12 +319,13 @@ class Mode(object): ...@@ -275,12 +319,13 @@ class Mode(object):
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
self.fn_time = 0 self.fn_time = 0
linker.mode = self #TODO: WHY IS THIS HERE? linker.mode = self # TODO: WHY IS THIS HERE?
self.optimizer_time = 0 self.optimizer_time = 0
self.linker_time = 0 self.linker_time = 0
def __str__(self): def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer) return "Mode(linker = %s, optimizer = %s)" % (
self.provided_linker, self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -298,17 +343,20 @@ class Mode(object): ...@@ -298,17 +343,20 @@ class Mode(object):
return (linker, optimizer) return (linker, optimizer)
def including(self, *tags): def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
#N.B. opt might be a Query instance, not sure what else it might be... #N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags)) return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags): def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.excluding(*tags)) return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags): def requiring(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer) link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.requiring(*tags)) return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
...@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE, ...@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN': FAST_RUN, 'FAST_RUN': FAST_RUN,
} }
instanciated_default_mode=None instanciated_default_mode = None
def get_mode(orig_string): def get_mode(orig_string):
if orig_string is None: if orig_string is None:
string = config.mode string = config.mode
else: else:
string = orig_string string = orig_string
if not isinstance(string, basestring): if not isinstance(string, basestring):
return string #it is hopefully already a mode... return string # it is hopefully already a mode...
global instanciated_default_mode global instanciated_default_mode
# The default mode is cached. However, config.mode can change # The default mode is cached. However, config.mode can change
# If instanciated_default_mode has the right class, use it. # If instanciated_default_mode has the right class, use it.
if orig_string is None and instanciated_default_mode: if orig_string is None and instanciated_default_mode:
if predefined_modes.has_key(string): if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__ default_mode_class = predefined_modes[string].__class__.__name__
else: else:
default_mode_class = string default_mode_class = string
...@@ -342,7 +392,7 @@ def get_mode(orig_string): ...@@ -342,7 +392,7 @@ def get_mode(orig_string):
default_mode_class): default_mode_class):
return instanciated_default_mode return instanciated_default_mode
if string in ['Mode','ProfileMode','DebugMode']: if string in ['Mode', 'ProfileMode', 'DebugMode']:
if string == 'DebugMode': if string == 'DebugMode':
#need to import later to break circular dependency. #need to import later to break circular dependency.
from debugmode import DebugMode from debugmode import DebugMode
...@@ -350,12 +400,13 @@ def get_mode(orig_string): ...@@ -350,12 +400,13 @@ def get_mode(orig_string):
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # The import is needed in case string is ProfileMode
from profilemode import ProfileMode,prof_mode_instance_to_print from profilemode import ProfileMode, prof_mode_instance_to_print
ret = eval(string+'(linker=config.linker, optimizer=config.optimizer)') ret = eval(string
elif predefined_modes.has_key(string): + '(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
raise Exception("No predefined mode exist for string: %s"%string) raise Exception("No predefined mode exist for string: %s" % string)
if orig_string is None: if orig_string is None:
# Build and cache the default mode # Build and cache the default mode
...@@ -374,12 +425,14 @@ def get_mode(orig_string): ...@@ -374,12 +425,14 @@ def get_mode(orig_string):
return ret return ret
def get_default_mode(): def get_default_mode():
return get_mode(None) return get_mode(None)
# Removed: use config.mode instead. # Removed: use config.mode instead.
#default_mode = config.mode #default_mode = config.mode
def register_mode(name, mode): def register_mode(name, mode):
"""Add a `Mode` which can be referred to by `name` in `function`.""" """Add a `Mode` which can be referred to by `name` in `function`."""
if name in predefined_modes: if name in predefined_modes:
......
...@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node): ...@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
merged_slices.append(slice1) merged_slices.append(slice1)
pos_1 += 1 pos_1 += 1
if pos_2 < len(slices2): if pos_2 < len(slices2):
merged_slices += slices2[pos_2:] merged_slices += slices2[pos_2:]
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论