提交 700d883b authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/gof/cmodule.py

上级 bee83a7a
"""Generate and compile C modules for Python,
"""
from __future__ import print_function
import atexit
import six.moves.cPickle as pickle
import logging
......@@ -15,12 +17,6 @@ import time
import platform
import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this
import theano
......@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
importlib = None
try:
import importlib
except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
AddConfigVar('cmodule.warn_no_version',
"If True, will print a warning when compiling one or more Op "
......@@ -131,15 +134,16 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
"""
return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc)
self.name, self.name, self.method, self.doc)
class DynamicModule(object):
def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as"
" the module's name.")
assert name is None, (
"The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as"
" the module's name.")
# While the module is not finalized, we can call add_...
# when it is finalized, a hash is computed and used instead of
# the placeholder, and as module name.
......@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
}};
""".format(name=self.hash_placeholder), file=stream)
print(("PyMODINIT_FUNC PyInit_%s(void) {" %
self.hash_placeholder), file=stream)
self.hash_placeholder), file=stream)
for block in self.init_blocks:
print(' ', block, file=stream)
print(" PyObject *m = PyModule_Create(&moduledef);", file=stream)
print(" return m;", file=stream)
else:
print(("PyMODINIT_FUNC init%s(void){" %
self.hash_placeholder), file=stream)
self.hash_placeholder), file=stream)
for block in self.init_blocks:
print(' ', block, file=stream)
print(' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.hash_placeholder), file=stream)
% self.hash_placeholder), file=stream)
print("}", file=stream)
def add_include(self, str):
......@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
if os.path.realpath(entry_1) == os.path.realpath(entry_2):
return True
if (os.path.basename(entry_1) == os.path.basename(entry_2) and
(os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
(os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
return True
return False
......@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part.
c_link_key = key[1]
for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types)
and key_element.startswith('md5:')):
if (isinstance(key_element, string_types) and
key_element.startswith('md5:')):
md5 = key_element[4:]
break
......@@ -761,9 +765,9 @@ class ModuleCache(object):
# simpler to implement).
rmtree(root, ignore_nocleanup=True,
msg=(
'invalid cache entry format -- this '
'should not happen unless your cache '
'was really old'),
'invalid cache entry format -- this '
'should not happen unless your cache '
'was really old'),
level=logging.WARN)
continue
......@@ -964,7 +968,7 @@ class ModuleCache(object):
# process that could be changing the file at the same
# time.
if (key[0] and not key_broken and
self.check_for_broken_eq):
self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken)
return module
......@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing
# the offending keys in the local context.
# key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
pass
elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file'
if msg:
raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" %
(msg, key_pkl, key))
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
......@@ -1195,10 +1198,10 @@ class ModuleCache(object):
if age_thresh_del < self.age_thresh_use:
if age_thresh_del > 0:
_logger.warning("Clearing modules that were not deemed "
"too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d",
age_thresh_del,
self.age_thresh_use)
"too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d",
age_thresh_del,
self.age_thresh_use)
else:
_logger.info("Clearing all modules.")
age_thresh_use = age_thresh_del
......@@ -1210,8 +1213,8 @@ class ModuleCache(object):
# processes and get all module that are too old to use
# (not loaded in self.entry_from_key).
too_old_to_use = self.refresh(
age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem)
age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem)
for entry in too_old_to_use:
# TODO: we are assuming that modules that haven't been
......@@ -1242,8 +1245,8 @@ class ModuleCache(object):
"""
with compilelock.lock_ctx():
self.clear_old(
age_thresh_del=-1.0,
delete_if_problem=delete_if_problem)
age_thresh_del=-1.0,
delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files:
self.clear_base_files()
......@@ -1333,7 +1336,7 @@ class ModuleCache(object):
if filename.startswith('tmp'):
try:
open(os.path.join(self.dirname, filename, 'key.pkl')
).close()
).close()
has_key = True
except IOError:
has_key = False
......@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
'was created prior to this call')
if _module_cache.dirname != dirname:
_logger.warning("Returning module cache instance with different "
"dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname)
"dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname)
return _module_cache
......@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
break
if ('g++' not in theano.config.cxx and
'clang++' not in theano.config.cxx):
'clang++' not in theano.config.cxx):
_logger.warn(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization"
......@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
selected_lines = []
for line in lines:
if ("COLLECT_GCC_OPTIONS=" in line or
"CFLAGS=" in line or
"CXXFLAGS=" in line or
"-march=native" in line):
"CFLAGS=" in line or
"CXXFLAGS=" in line or
"-march=native" in line):
continue
elif "-march=" in line:
selected_lines.append(line.strip())
......@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines:
if line.startswith(part[0]):
part2 = [p for p in join_options(line.split())
if (not 'march' in p and
not 'mtune' in p and
not 'target-cpu' in p)]
if ('march' not in p and
'mtune' not in p and
'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++
......@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error():
# Print command line when a problem occurred.
print((
"Problem occurred during compilation with the "
"command line below:"), file=sys.stderr)
print(("Problem occurred during compilation with the "
"command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr)
try:
......
......@@ -244,7 +244,6 @@ whitelist_flake8 = [
"gof/graph.py",
"gof/__init__.py",
"gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论