提交 57bea857 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove the restriction that we need to create the compiledir to have the source…

Remove the restriction that we need to create the compiledir to have the source code by removing the compile_cmodule_by_step interface.
上级 2a7391a7
...@@ -1282,26 +1282,19 @@ class CLinker(link.Linker): ...@@ -1282,26 +1282,19 @@ class CLinker(link.Linker):
return ((), sig) return ((), sig)
return version, sig return version, sig
def compile_cmodule(self, location=None): def get_src_code(self):
"""
Compile the module and return it.
""" """
# Go through all steps of the compilation process. Returns the source code for the module represented by this linker.
for step_result in self.compile_cmodule_by_step(location=location):
pass
# And return the output of the last step, which should be the module
# itself.
return step_result
def compile_cmodule_by_step(self, location=None):
""" """
This method is a callback for `ModuleCache.module_from_key`. if not hasattr(self, '_source'):
mod = self.build_dynamic_module()
self._source = mod.code()
return self._source
It is a generator (thus the 'by step'), so that: def compile_cmodule(self, location=None):
- it first yields the module's C code """
- it last yields the module itself This compiles the source code for this linker and returns a
- it may yield other intermediate outputs in-between if needed loaded module.
in the future (but this is not currently the case)
""" """
if location is None: if location is None:
location = cmodule.dlimport_workdir(config.compiledir) location = cmodule.dlimport_workdir(config.compiledir)
...@@ -1324,48 +1317,46 @@ class CLinker(link.Linker): ...@@ -1324,48 +1317,46 @@ class CLinker(link.Linker):
if 'amdlibm' in libs: if 'amdlibm' in libs:
libs.remove('amdlibm') libs.remove('amdlibm')
src_code = mod.code() src_code = mod.code()
yield src_code
get_lock() get_lock()
try: try:
_logger.debug("LOCATION %s", str(location)) _logger.debug("LOCATION %s", str(location))
try: module = c_compiler.compile_str(
module = c_compiler.compile_str( module_name=mod.code_hash,
module_name=mod.code_hash, src_code=mod.code(),
src_code=src_code, location=location,
location=location, include_dirs=self.header_dirs(),
include_dirs=self.header_dirs(), lib_dirs=self.lib_dirs(),
lib_dirs=self.lib_dirs(), libs=libs,
libs=libs, preargs=preargs)
preargs=preargs) except Exception, e:
except Exception, e: e.args += (str(self.fgraph),)
e.args += (str(self.fgraph),) raise
raise
finally: finally:
release_lock() release_lock()
return module
yield module
def build_dynamic_module(self): def build_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code """Return a cmodule.DynamicModule instance full of the code
for our fgraph. for our fgraph.
""" """
self.code_gen() if not hasattr(self, '_mod'):
self.code_gen()
mod = cmodule.DynamicModule() mod = cmodule.DynamicModule()
# The code of instantiate # The code of instantiate
# the 1 is for error_storage # the 1 is for error_storage
code = self.instantiate_code(1 + len(self.args)) code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code, instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS) method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, #['error_storage'] + argnames,
#local_dict = d, #local_dict = d,
#global_dict = {}) #global_dict = {})
# Static methods that can run and destroy the struct built by # Static methods that can run and destroy the struct built by
# instantiate. # instantiate.
if PY3: if PY3:
static = """ static = """
static int {struct_name}_executor({struct_name} *self) {{ static int {struct_name}_executor({struct_name} *self) {{
return self->run(); return self->run();
}} }}
...@@ -1375,8 +1366,8 @@ class CLinker(link.Linker): ...@@ -1375,8 +1366,8 @@ class CLinker(link.Linker):
delete self; delete self;
}} }}
""".format(struct_name=self.struct_name) """.format(struct_name=self.struct_name)
else: else:
static = """ static = """
static int %(struct_name)s_executor(%(struct_name)s* self) { static int %(struct_name)s_executor(%(struct_name)s* self) {
return self->run(); return self->run();
} }
...@@ -1387,17 +1378,17 @@ class CLinker(link.Linker): ...@@ -1387,17 +1378,17 @@ class CLinker(link.Linker):
""" % dict(struct_name=self.struct_name) """ % dict(struct_name=self.struct_name)
# We add all the support code, compile args, headers and libs we need. # We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply: for support_code in self.support_code() + self.c_support_code_apply:
mod.add_support_code(support_code) mod.add_support_code(support_code)
mod.add_support_code(self.struct_code) mod.add_support_code(self.struct_code)
mod.add_support_code(static) mod.add_support_code(static)
mod.add_function(instantiate) mod.add_function(instantiate)
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
for init_code_block in self.init_code() + self.c_init_code_apply: for init_code_block in self.init_code() + self.c_init_code_apply:
mod.add_init_code(init_code_block) mod.add_init_code(init_code_block)
self._mod = mod
return mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): keep_lock=False):
...@@ -1421,7 +1412,7 @@ class CLinker(link.Linker): ...@@ -1421,7 +1412,7 @@ class CLinker(link.Linker):
module = self.compile_cmodule() module = self.compile_cmodule()
else: else:
module = get_module_cache().module_from_key( module = get_module_cache().module_from_key(
key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock) key=key, lnk=self, keep_lock=keep_lock)
vars = self.inputs + self.outputs + self.orphans vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
......
...@@ -995,42 +995,38 @@ class ModuleCache(object): ...@@ -995,42 +995,38 @@ class ModuleCache(object):
self._update_mappings(key, key_data, module.__file__) self._update_mappings(key, key_data, module.__file__)
return key_data return key_data
def module_from_key(self, key, fn=None, keep_lock=False): def module_from_key(self, key, lnk=None, keep_lock=False):
""" """
:param fn: A callable object that will return an iterable object when Return a module from the cache, compiling it if necessary.
called, such that the first element in this iterable object is the
source code of the module, and the last element is the module itself. :param key: The key object associated with the module. If this
`fn` is called only if the key is not already in the cache, with hits a match, we avoid compilation.
a single keyword argument `location` that is the path to the directory
where the module should be compiled. :param lnk: Usually a CLinker instance, but it can be any
object that defines the `get_src_code()` and
`compile_cmodule(location)` functions. The first
one returns the source code of the module to
load/compile and the second performs the actual
compilation.
:param keep_lock: If True, the compilation lock will not be
released if taken.
""" """
# Is the module in the cache? # Is the module in the cache?
module = self._get_from_key(key) module = self._get_from_key(key)
if module is not None: if module is not None:
return module return module
location = None
nocleanup = False
lock_taken = False lock_taken = False
try: src_code = lnk.get_src_code()
location = dlimport_workdir(self.dirname) # Is the source code already in the cache?
except OSError, e: module_hash = get_module_hash(src_code, key)
_logger.error(e) module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
if e.errno == 31: if module is not None:
_logger.error('There are %i files in %s', return module
len(os.listdir(config.compiledir)),
config.compiledir)
raise
try:
# Is the source code already in the cache?
compile_steps = fn(location=location).__iter__()
src_code = next(compile_steps)
module_hash = get_module_hash(src_code, key)
module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
if module is not None:
return module
try:
# Compile the module since it's not cached # Compile the module since it's not cached
compilelock.get_lock() compilelock.get_lock()
lock_taken = True lock_taken = True
...@@ -1039,40 +1035,45 @@ class ModuleCache(object): ...@@ -1039,40 +1035,45 @@ class ModuleCache(object):
# Need no_clean_empty otherwise it deletes the workdir we # Need no_clean_empty otherwise it deletes the workdir we
# created above. # created above.
self.refresh(no_clean_empty=True) self.refresh()
module = self._get_from_key(key) module = self._get_from_key(key)
if module is not None: if module is not None:
return module return module
module = self._get_from_hash(module_hash, key, keep_lock=keep_lock) module = self._get_from_hash(module_hash, key)
if module is not None: if module is not None:
return module return module
hash_key = hash(key) hash_key = hash(key)
while True: nocleanup = False
try: try:
# The module should be returned by the last location = dlimport_workdir(self.dirname)
# step of the compilation. module = lnk.compile_cmodule(location)
module = next(compile_steps) name = module.__file__
except StopIteration: assert name.startswith(location)
break assert name not in self.module_from_name
name = module.__file__ self.module_from_name[name] = module
assert name.startswith(location) nocleanup = True
assert name not in self.module_from_name except OSError, e:
self.module_from_name[name] = module _logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
finally:
if not nocleanup:
_rmtree(location, ignore_if_missing=True,
msg='exception during compilation')
# Changing the hash of the key is not allowed during # Changing the hash of the key is not allowed during
# compilation. # compilation.
assert hash(key) == hash_key assert hash(key) == hash_key
nocleanup = True
key_data = self._add_to_cache(module, key, module_hash) key_data = self._add_to_cache(module, key, module_hash)
self.module_hash_to_key_data[module_hash] = key_data self.module_hash_to_key_data[module_hash] = key_data
finally: finally:
if not nocleanup:
_rmtree(location, ignore_if_missing=True,
msg='exception during compilation')
# Release lock if needed. # Release lock if needed.
if not keep_lock and lock_taken: if not keep_lock and lock_taken:
compilelock.release_lock() compilelock.release_lock()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论