提交 2a1aa262 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/mode.py

上级 b5263298
...@@ -2,35 +2,33 @@ ...@@ -2,35 +2,33 @@
""" """
from __future__ import print_function from __future__ import print_function
import logging import logging
import warnings
from textwrap import dedent
import numpy import numpy
import theano import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.compile.ops import register_view_op_c_code, _output_guard from theano.compile.ops import _output_guard
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
("When using the default mode, we will remove optimizer with these " ("When using the default mode, we will remove optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
("When using the default mode, we will add optimizer with these tags. " ("When using the default mode, we will add optimizer with "
"Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
("When using the default mode, we will require optimizer with these " ("When using the default mode, we will require optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
def check_equal(x, y): def check_equal(x, y):
...@@ -54,11 +52,11 @@ def check_equal(x, y): ...@@ -54,11 +52,11 @@ def check_equal(x, y):
or x.shape != y.shape or x.shape != y.shape
or numpy.any(abs(x - y) > 1e-10)): or numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
else: else:
if x != y: if x != y:
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for # If a string is passed as the linker argument in the constructor for
...@@ -144,11 +142,11 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -144,11 +142,11 @@ class AddDestroyHandler(gof.Optimizer):
for o in fgraph.outputs: for o in fgraph.outputs:
try: try:
fgraph.replace_validate(o, _output_guard(o), fgraph.replace_validate(o, _output_guard(o),
reason='output_guard') reason='output_guard')
_logger.info("Output variable %s required output_guard, " _logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against " "how was this output left unprotected against "
"destructive operations?" "destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError:
# This output is already impossible to destroy. # This output is already impossible to destroy.
# No guard necessary # No guard necessary
...@@ -188,50 +186,50 @@ class PrintCurrentFunctionGraph(gof.Optimizer): ...@@ -188,50 +186,50 @@ class PrintCurrentFunctionGraph(gof.Optimizer):
optdb = gof.SequenceDB() optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge') 0, 'fast_run', 'fast_compile', 'merge')
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(), optdb.register('canonicalize', gof.EquilibriumDB(),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge') 1.2, 'fast_run', 'fast_compile', 'merge')
optdb.register('Print1.21', PrintCurrentFunctionGraph('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentFunctionGraph('Post-canonicalize'),
1.21,) # 'fast_run', 'fast_compile') 1.21,) # 'fast_run', 'fast_compile')
# replace unstable subgraphs # replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(), optdb.register('stabilize', gof.EquilibriumDB(),
1.5, 'fast_run') 1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentFunctionGraph('Post-stabilize'), optdb.register('Print1.51', PrintCurrentFunctionGraph('Post-stabilize'),
1.51,) # 'fast_run', 'fast_compile') 1.51,) # 'fast_run', 'fast_compile')
# misc special cases for speed # misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(), optdb.register('specialize', gof.EquilibriumDB(),
2, 'fast_run', 'fast_compile_gpu') 2, 'fast_run', 'fast_compile_gpu')
# misc special cases for speed that break canonicalization # misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(), optdb.register('uncanonicalize', gof.EquilibriumDB(),
3, 'fast_run') 3, 'fast_run')
# misc special cases for speed that are dependent on the device. # misc special cases for speed that are dependent on the device.
optdb.register('specialize_device', gof.EquilibriumDB(), optdb.register('specialize_device', gof.EquilibriumDB(),
48.6, 'fast_run') # must be after gpu stuff at 48.5 48.6, 'fast_run') # must be after gpu stuff at 48.5
# especially constant merge # especially constant merge
optdb.register('merge2', gof.MergeOptimizer(), optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge') 49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(), optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4) 49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
# final pass just to make sure # final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(), optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run', 'merge') 100, 'fast_run', 'merge')
class Mode(object): class Mode(object):
...@@ -287,7 +285,8 @@ class Mode(object): ...@@ -287,7 +285,8 @@ class Mode(object):
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__, return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
self.provided_linker, self.provided_optimizer) self.provided_linker,
self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -306,19 +305,19 @@ class Mode(object): ...@@ -306,19 +305,19 @@ class Mode(object):
def including(self, *tags): def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
# N.B. opt might be a Query instance, not sure what else it might be... # N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags)) return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags): def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.excluding(*tags)) return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags): def requiring(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.requiring(*tags)) return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
...@@ -364,10 +363,11 @@ def get_mode(orig_string): ...@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker. # DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # This might be required if the string is 'ProfileMode'
from profilemode import ProfileMode, prof_mode_instance_to_print from profilemode import ProfileMode # noqa
ret = eval(string from profilemode import prof_mode_instance_to_print
+ '(linker=config.linker, optimizer=config.optimizer)') ret = eval(string +
'(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes: elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
......
...@@ -38,7 +38,6 @@ whitelist_flake8 = [ ...@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py", "tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/mode.py",
"compile/profilemode.py", "compile/profilemode.py",
"compile/builders.py", "compile/builders.py",
"compile/__init__.py", "compile/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论