提交 1a369c4d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2321 from abergeron/fix_cache_access

Change the ModuleCache to only take the lock for actual compilation.
...@@ -1281,30 +1281,18 @@ class CLinker(link.Linker): ...@@ -1281,30 +1281,18 @@ class CLinker(link.Linker):
return ((), sig) return ((), sig)
return version, sig return version, sig
def get_src_code(self):
mod = self.get_dynamic_module()
return mod.code()
def compile_cmodule(self, location=None): def compile_cmodule(self, location=None):
""" """
Compile the module and return it. This compiles the source code for this linker and returns a
""" loaded module.
# Go through all steps of the compilation process.
for step_result in self.compile_cmodule_by_step(location=location):
pass
# And return the output of the last step, which should be the module
# itself.
return step_result
def compile_cmodule_by_step(self, location=None):
"""
This method is a callback for `ModuleCache.module_from_key`.
It is a generator (thus the 'by step'), so that:
- it first yields the module's C code
- it last yields the module itself
- it may yield other intermediate outputs in-between if needed
in the future (but this is not currently the case)
""" """
if location is None: if location is None:
location = cmodule.dlimport_workdir(config.compiledir) location = cmodule.dlimport_workdir(config.compiledir)
mod = self.build_dynamic_module() mod = self.get_dynamic_module()
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
libs = self.libraries() libs = self.libraries()
preargs = self.compile_args() preargs = self.compile_args()
...@@ -1323,48 +1311,49 @@ class CLinker(link.Linker): ...@@ -1323,48 +1311,49 @@ class CLinker(link.Linker):
if 'amdlibm' in libs: if 'amdlibm' in libs:
libs.remove('amdlibm') libs.remove('amdlibm')
src_code = mod.code() src_code = mod.code()
yield src_code
get_lock() get_lock()
try: try:
_logger.debug("LOCATION %s", str(location)) _logger.debug("LOCATION %s", str(location))
try: module = c_compiler.compile_str(
module = c_compiler.compile_str( module_name=mod.code_hash,
module_name=mod.code_hash, src_code=mod.code(),
src_code=src_code, location=location,
location=location, include_dirs=self.header_dirs(),
include_dirs=self.header_dirs(), lib_dirs=self.lib_dirs(),
lib_dirs=self.lib_dirs(), libs=libs,
libs=libs, preargs=preargs)
preargs=preargs) except Exception, e:
except Exception, e: e.args += (str(self.fgraph),)
e.args += (str(self.fgraph),) raise
raise
finally: finally:
release_lock() release_lock()
return module
yield module def get_dynamic_module(self):
def build_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code """Return a cmodule.DynamicModule instance full of the code
for our fgraph. for our fgraph.
This method is cached on the first call so it can be called
multiple times without penalty.
""" """
self.code_gen() if not hasattr(self, '_mod'):
self.code_gen()
mod = cmodule.DynamicModule() mod = cmodule.DynamicModule()
# The code of instantiate # The code of instantiate
# the 1 is for error_storage # the 1 is for error_storage
code = self.instantiate_code(1 + len(self.args)) code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code, instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS) method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, #['error_storage'] + argnames,
#local_dict = d, #local_dict = d,
#global_dict = {}) #global_dict = {})
# Static methods that can run and destroy the struct built by # Static methods that can run and destroy the struct built by
# instantiate. # instantiate.
if PY3: if PY3:
static = """ static = """
static int {struct_name}_executor({struct_name} *self) {{ static int {struct_name}_executor({struct_name} *self) {{
return self->run(); return self->run();
}} }}
...@@ -1374,8 +1363,8 @@ class CLinker(link.Linker): ...@@ -1374,8 +1363,8 @@ class CLinker(link.Linker):
delete self; delete self;
}} }}
""".format(struct_name=self.struct_name) """.format(struct_name=self.struct_name)
else: else:
static = """ static = """
static int %(struct_name)s_executor(%(struct_name)s* self) { static int %(struct_name)s_executor(%(struct_name)s* self) {
return self->run(); return self->run();
} }
...@@ -1386,17 +1375,17 @@ class CLinker(link.Linker): ...@@ -1386,17 +1375,17 @@ class CLinker(link.Linker):
""" % dict(struct_name=self.struct_name) """ % dict(struct_name=self.struct_name)
# We add all the support code, compile args, headers and libs we need. # We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply: for support_code in self.support_code() + self.c_support_code_apply:
mod.add_support_code(support_code) mod.add_support_code(support_code)
mod.add_support_code(self.struct_code) mod.add_support_code(self.struct_code)
mod.add_support_code(static) mod.add_support_code(static)
mod.add_function(instantiate) mod.add_function(instantiate)
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
for init_code_block in self.init_code() + self.c_init_code_apply: for init_code_block in self.init_code() + self.c_init_code_apply:
mod.add_init_code(init_code_block) mod.add_init_code(init_code_block)
self._mod = mod
return mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): keep_lock=False):
...@@ -1420,7 +1409,7 @@ class CLinker(link.Linker): ...@@ -1420,7 +1409,7 @@ class CLinker(link.Linker):
module = self.compile_cmodule() module = self.compile_cmodule()
else: else:
module = get_module_cache().module_from_key( module = get_module_cache().module_from_key(
key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock) key=key, lnk=self, keep_lock=keep_lock)
vars = self.inputs + self.outputs + self.orphans vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
......
差异被折叠。
...@@ -8,6 +8,8 @@ import socket # only used for gethostname() ...@@ -8,6 +8,8 @@ import socket # only used for gethostname()
import time import time
import logging import logging
from contextlib import contextmanager
from theano import config from theano import config
from theano.configparser import AddConfigVar, IntParam from theano.configparser import AddConfigVar, IntParam
...@@ -44,6 +46,14 @@ def force_unlock(): ...@@ -44,6 +46,14 @@ def force_unlock():
release_lock() release_lock()
@contextmanager
def lock_ctx(lock_dir=None, keep_lock=False, **kw):
get_lock(lock_dir=lock_dir, **kw)
yield
if not keep_lock:
release_lock()
def get_lock(lock_dir=None, **kw): def get_lock(lock_dir=None, **kw):
""" """
Obtain lock on compilation directory. Obtain lock on compilation directory.
......
File mode changed from 100755 to 100644
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论