提交 700d883b authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/gof/cmodule.py

上级 bee83a7a
"""Generate and compile C modules for Python, """Generate and compile C modules for Python,
""" """
from __future__ import print_function from __future__ import print_function
import atexit import atexit
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import logging import logging
...@@ -15,12 +17,6 @@ import time ...@@ -15,12 +17,6 @@ import time
import platform import platform
import distutils.sysconfig import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this import numpy.distutils # TODO: TensorType should handle this
import theano import theano
...@@ -38,7 +34,14 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth ...@@ -38,7 +34,14 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link', importlib = None
try:
import importlib
except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous " "If set to True, breaks certain MacOS installations with the infamous "
"Bus Error", "Bus Error",
BoolParam(False)) BoolParam(False))
...@@ -136,7 +139,8 @@ class ExtFunction(object): ...@@ -136,7 +139,8 @@ class ExtFunction(object):
class DynamicModule(object): class DynamicModule(object):
def __init__(self, name=None): def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule" assert name is None, (
"The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'" " cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as" " will be automatically computed and can be used as"
" the module's name.") " the module's name.")
...@@ -429,8 +433,8 @@ def get_safe_part(key): ...@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part. # Find the md5 hash part.
c_link_key = key[1] c_link_key = key[1]
for key_element in c_link_key[1:]: for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types) if (isinstance(key_element, string_types) and
and key_element.startswith('md5:')): key_element.startswith('md5:')):
md5 = key_element[4:] md5 = key_element[4:]
break break
...@@ -1149,15 +1153,14 @@ class ModuleCache(object): ...@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing # This is to make debugging in pdb easier, by providing
# the offending keys in the local context. # the offending keys in the local context.
# key_data_keys = list(key_data.keys) # key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
pass pass
elif found > 1: elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file' msg = 'Multiple equal keys found in unpickled KeyData file'
if msg: if msg:
raise AssertionError( raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your " "%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % "Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
(msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal # Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys # to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this # with the same version part and config md5, since we can assume this
...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler): ...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines: for line in default_lines:
if line.startswith(part[0]): if line.startswith(part[0]):
part2 = [p for p in join_options(line.split()) part2 = [p for p in join_options(line.split())
if (not 'march' in p and if ('march' not in p and
not 'mtune' in p and 'mtune' not in p and
not 'target-cpu' in p)] 'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2] new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option # Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++ # of clang, with '-march=value', for g++
...@@ -2021,13 +2024,12 @@ class GCC_compiler(Compiler): ...@@ -2021,13 +2024,12 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename) cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs]) cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs]) cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd # print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd)) _logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error(): def print_command_line_error():
# Print command line when a problem occurred. # Print command line when a problem occurred.
print(( print(("Problem occurred during compilation with the "
"Problem occurred during compilation with the "
"command line below:"), file=sys.stderr) "command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr) print(' '.join(cmd), file=sys.stderr)
......
...@@ -244,7 +244,6 @@ whitelist_flake8 = [ ...@@ -244,7 +244,6 @@ whitelist_flake8 = [
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py", "gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py", "gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py", "gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论