提交 2a1aa262 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/mode.py

上级 b5263298
......@@ -2,35 +2,33 @@
"""
from __future__ import print_function
import logging
import warnings
from textwrap import dedent
import numpy
import theano
import theano
from theano import gof
import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam
from theano.compile.ops import register_view_op_c_code, _output_guard
from theano.compile.ops import _output_guard
_logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding',
("When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
("When using the default mode, we will remove optimizer with "
"these tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
AddConfigVar('optimizer_including',
("When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
("When using the default mode, we will add optimizer with "
"these tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
AddConfigVar('optimizer_requiring',
("When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
("When using the default mode, we will require optimizer with "
"these tags. Separate tags with ':'."),
StrParam("", allow_override=False),
in_c_key=False)
def check_equal(x, y):
......@@ -54,11 +52,11 @@ def check_equal(x, y):
or x.shape != y.shape
or numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
{'performlinker': x, 'clinker': y})
else:
if x != y:
raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y})
{'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for
......@@ -144,11 +142,11 @@ class AddDestroyHandler(gof.Optimizer):
for o in fgraph.outputs:
try:
fgraph.replace_validate(o, _output_guard(o),
reason='output_guard')
reason='output_guard')
_logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
% o)
"how was this output left unprotected against "
"destructive operations?"
% o)
except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
......@@ -188,50 +186,50 @@ class PrintCurrentFunctionGraph(gof.Optimizer):
optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge')
0, 'fast_run', 'fast_compile', 'merge')
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(),
1, 'fast_run', 'fast_compile')
1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge')
1.2, 'fast_run', 'fast_compile', 'merge')
optdb.register('Print1.21', PrintCurrentFunctionGraph('Post-canonicalize'),
1.21,) # 'fast_run', 'fast_compile')
1.21,) # 'fast_run', 'fast_compile')
# replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(),
1.5, 'fast_run')
1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentFunctionGraph('Post-stabilize'),
1.51,) # 'fast_run', 'fast_compile')
1.51,) # 'fast_run', 'fast_compile')
# misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(),
2, 'fast_run', 'fast_compile_gpu')
2, 'fast_run', 'fast_compile_gpu')
# misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(),
3, 'fast_run')
3, 'fast_run')
# misc special cases for speed that are dependent on the device.
optdb.register('specialize_device', gof.EquilibriumDB(),
48.6, 'fast_run') # must be after gpu stuff at 48.5
48.6, 'fast_run') # must be after gpu stuff at 48.5
# especially constant merge
optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge')
49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4)
49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace')
49.5, 'fast_run', 'inplace')
# final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run', 'merge')
100, 'fast_run', 'merge')
class Mode(object):
......@@ -287,7 +285,8 @@ class Mode(object):
def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
self.provided_linker, self.provided_optimizer)
self.provided_linker,
self.provided_optimizer)
def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query):
......@@ -306,19 +305,19 @@ class Mode(object):
def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
self.provided_optimizer)
# N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or
......@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
else:
# The import is needed in case string is ProfileMode
from profilemode import ProfileMode, prof_mode_instance_to_print
ret = eval(string
+ '(linker=config.linker, optimizer=config.optimizer)')
# This might be required if the string is 'ProfileMode'
from profilemode import ProfileMode # noqa
from profilemode import prof_mode_instance_to_print
ret = eval(string +
'(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes:
ret = predefined_modes[string]
else:
......
......@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py",
"tests/disturb_mem.py",
"tests/unittest_tools.py",
"compile/mode.py",
"compile/profilemode.py",
"compile/builders.py",
"compile/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论