提交 700d883b authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/gof/cmodule.py

上级 bee83a7a
"""Generate and compile C modules for Python, """Generate and compile C modules for Python,
""" """
from __future__ import print_function from __future__ import print_function
import atexit import atexit
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import logging import logging
...@@ -15,12 +17,6 @@ import time ...@@ -15,12 +17,6 @@ import time
import platform import platform
import distutils.sysconfig import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this import numpy.distutils # TODO: TensorType should handle this
import theano import theano
...@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth ...@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link', importlib = None
"If set to True, breaks certain MacOS installations with the infamous " try:
"Bus Error", import importlib
BoolParam(False)) except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
AddConfigVar('cmodule.warn_no_version', AddConfigVar('cmodule.warn_no_version',
"If True, will print a warning when compiling one or more Op " "If True, will print a warning when compiling one or more Op "
...@@ -131,15 +134,16 @@ class ExtFunction(object): ...@@ -131,15 +134,16 @@ class ExtFunction(object):
It goes into the DynamicModule's method table. It goes into the DynamicModule's method table.
""" """
return '\t{"%s", %s, %s, "%s"}' % ( return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc) self.name, self.name, self.method, self.doc)
class DynamicModule(object): class DynamicModule(object):
def __init__(self, name=None): def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule" assert name is None, (
" cannot be specified anymore. Instead, 'code_hash'" "The 'name' parameter of DynamicModule"
" will be automatically computed and can be used as" " cannot be specified anymore. Instead, 'code_hash'"
" the module's name.") " will be automatically computed and can be used as"
" the module's name.")
# While the module is not finalized, we can call add_... # While the module is not finalized, we can call add_...
# when it is finalized, a hash is computed and used instead of # when it is finalized, a hash is computed and used instead of
# the placeholder, and as module name. # the placeholder, and as module name.
...@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{ ...@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
}}; }};
""".format(name=self.hash_placeholder), file=stream) """.format(name=self.hash_placeholder), file=stream)
print(("PyMODINIT_FUNC PyInit_%s(void) {" % print(("PyMODINIT_FUNC PyInit_%s(void) {" %
self.hash_placeholder), file=stream) self.hash_placeholder), file=stream)
for block in self.init_blocks: for block in self.init_blocks:
print(' ', block, file=stream) print(' ', block, file=stream)
print(" PyObject *m = PyModule_Create(&moduledef);", file=stream) print(" PyObject *m = PyModule_Create(&moduledef);", file=stream)
print(" return m;", file=stream) print(" return m;", file=stream)
else: else:
print(("PyMODINIT_FUNC init%s(void){" % print(("PyMODINIT_FUNC init%s(void){" %
self.hash_placeholder), file=stream) self.hash_placeholder), file=stream)
for block in self.init_blocks: for block in self.init_blocks:
print(' ', block, file=stream) print(' ', block, file=stream)
print(' ', ('(void) Py_InitModule("%s", MyMethods);' print(' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.hash_placeholder), file=stream) % self.hash_placeholder), file=stream)
print("}", file=stream) print("}", file=stream)
def add_include(self, str): def add_include(self, str):
...@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2): ...@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
if os.path.realpath(entry_1) == os.path.realpath(entry_2): if os.path.realpath(entry_1) == os.path.realpath(entry_2):
return True return True
if (os.path.basename(entry_1) == os.path.basename(entry_2) and if (os.path.basename(entry_1) == os.path.basename(entry_2) and
(os.path.basename(os.path.dirname(entry_1)) == (os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')): os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
return True return True
return False return False
...@@ -429,8 +433,8 @@ def get_safe_part(key): ...@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part. # Find the md5 hash part.
c_link_key = key[1] c_link_key = key[1]
for key_element in c_link_key[1:]: for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types) if (isinstance(key_element, string_types) and
and key_element.startswith('md5:')): key_element.startswith('md5:')):
md5 = key_element[4:] md5 = key_element[4:]
break break
...@@ -761,9 +765,9 @@ class ModuleCache(object): ...@@ -761,9 +765,9 @@ class ModuleCache(object):
# simpler to implement). # simpler to implement).
rmtree(root, ignore_nocleanup=True, rmtree(root, ignore_nocleanup=True,
msg=( msg=(
'invalid cache entry format -- this ' 'invalid cache entry format -- this '
'should not happen unless your cache ' 'should not happen unless your cache '
'was really old'), 'was really old'),
level=logging.WARN) level=logging.WARN)
continue continue
...@@ -964,7 +968,7 @@ class ModuleCache(object): ...@@ -964,7 +968,7 @@ class ModuleCache(object):
# process that could be changing the file at the same # process that could be changing the file at the same
# time. # time.
if (key[0] and not key_broken and if (key[0] and not key_broken and
self.check_for_broken_eq): self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl) self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken) self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken)
return module return module
...@@ -1149,15 +1153,14 @@ class ModuleCache(object): ...@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing # This is to make debugging in pdb easier, by providing
# the offending keys in the local context. # the offending keys in the local context.
# key_data_keys = list(key_data.keys) # key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
pass pass
elif found > 1: elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file' msg = 'Multiple equal keys found in unpickled KeyData file'
if msg: if msg:
raise AssertionError( raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your " "%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % "Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
(msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal # Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys # to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this # with the same version part and config md5, since we can assume this
...@@ -1195,10 +1198,10 @@ class ModuleCache(object): ...@@ -1195,10 +1198,10 @@ class ModuleCache(object):
if age_thresh_del < self.age_thresh_use: if age_thresh_del < self.age_thresh_use:
if age_thresh_del > 0: if age_thresh_del > 0:
_logger.warning("Clearing modules that were not deemed " _logger.warning("Clearing modules that were not deemed "
"too old to use: age_thresh_del=%d, " "too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d", "self.age_thresh_use=%d",
age_thresh_del, age_thresh_del,
self.age_thresh_use) self.age_thresh_use)
else: else:
_logger.info("Clearing all modules.") _logger.info("Clearing all modules.")
age_thresh_use = age_thresh_del age_thresh_use = age_thresh_del
...@@ -1210,8 +1213,8 @@ class ModuleCache(object): ...@@ -1210,8 +1213,8 @@ class ModuleCache(object):
# processes and get all module that are too old to use # processes and get all module that are too old to use
# (not loaded in self.entry_from_key). # (not loaded in self.entry_from_key).
too_old_to_use = self.refresh( too_old_to_use = self.refresh(
age_thresh_use=age_thresh_use, age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem) delete_if_problem=delete_if_problem)
for entry in too_old_to_use: for entry in too_old_to_use:
# TODO: we are assuming that modules that haven't been # TODO: we are assuming that modules that haven't been
...@@ -1242,8 +1245,8 @@ class ModuleCache(object): ...@@ -1242,8 +1245,8 @@ class ModuleCache(object):
""" """
with compilelock.lock_ctx(): with compilelock.lock_ctx():
self.clear_old( self.clear_old(
age_thresh_del=-1.0, age_thresh_del=-1.0,
delete_if_problem=delete_if_problem) delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age) self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files: if clear_base_files:
self.clear_base_files() self.clear_base_files()
...@@ -1333,7 +1336,7 @@ class ModuleCache(object): ...@@ -1333,7 +1336,7 @@ class ModuleCache(object):
if filename.startswith('tmp'): if filename.startswith('tmp'):
try: try:
open(os.path.join(self.dirname, filename, 'key.pkl') open(os.path.join(self.dirname, filename, 'key.pkl')
).close() ).close()
has_key = True has_key = True
except IOError: except IOError:
has_key = False has_key = False
...@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None): ...@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
'was created prior to this call') 'was created prior to this call')
if _module_cache.dirname != dirname: if _module_cache.dirname != dirname:
_logger.warning("Returning module cache instance with different " _logger.warning("Returning module cache instance with different "
"dirname (%s) than you requested (%s)", "dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname) _module_cache.dirname, dirname)
return _module_cache return _module_cache
...@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler): ...@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
break break
if ('g++' not in theano.config.cxx and if ('g++' not in theano.config.cxx and
'clang++' not in theano.config.cxx): 'clang++' not in theano.config.cxx):
_logger.warn( _logger.warn(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be" "OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization" " the g++ compiler. So we disable the compiler optimization"
...@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler): ...@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
selected_lines = [] selected_lines = []
for line in lines: for line in lines:
if ("COLLECT_GCC_OPTIONS=" in line or if ("COLLECT_GCC_OPTIONS=" in line or
"CFLAGS=" in line or "CFLAGS=" in line or
"CXXFLAGS=" in line or "CXXFLAGS=" in line or
"-march=native" in line): "-march=native" in line):
continue continue
elif "-march=" in line: elif "-march=" in line:
selected_lines.append(line.strip()) selected_lines.append(line.strip())
...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler): ...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines: for line in default_lines:
if line.startswith(part[0]): if line.startswith(part[0]):
part2 = [p for p in join_options(line.split()) part2 = [p for p in join_options(line.split())
if (not 'march' in p and if ('march' not in p and
not 'mtune' in p and 'mtune' not in p and
not 'target-cpu' in p)] 'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2] new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option # Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++ # of clang, with '-march=value', for g++
...@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler): ...@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename) cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs]) cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs]) cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd # print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd)) _logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error(): def print_command_line_error():
# Print command line when a problem occurred. # Print command line when a problem occurred.
print(( print(("Problem occurred during compilation with the "
"Problem occurred during compilation with the " "command line below:"), file=sys.stderr)
"command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr) print(' '.join(cmd), file=sys.stderr)
try: try:
......
...@@ -244,7 +244,6 @@ whitelist_flake8 = [ ...@@ -244,7 +244,6 @@ whitelist_flake8 = [
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py", "gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py", "gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py", "gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论