提交 cb6152de authored 作者: James Bergstra's avatar James Bergstra

Added dlimport_workdir function to cmodule, to standardize the way in which

compiledir is mapped to a tmpdir where an .so should go.
上级 f3dfcc7b
...@@ -777,9 +777,7 @@ class CLinker(link.Linker): ...@@ -777,9 +777,7 @@ class CLinker(link.Linker):
This method is a callback for `ModuleCache.module_from_key` This method is a callback for `ModuleCache.module_from_key`
""" """
if location is None: if location is None:
location = get_compiledir() location = cmodule.dlimport_workdir(get_compiledir())
#backport
#location = get_compiledir() if location is None else location
mod = self.build_dynamic_module() mod = self.build_dynamic_module()
get_lock() get_lock()
try: try:
......
...@@ -166,6 +166,11 @@ def dlimport(fullpath, suffix=None): ...@@ -166,6 +166,11 @@ def dlimport(fullpath, suffix=None):
assert fullpath.startswith(rval.__file__) assert fullpath.startswith(rval.__file__)
return rval return rval
def dlimport_workdir(basedir):
"""Return a directory where you should put your .so file for dlimport to be able to load
it, given a basedir which should normally be the result of get_compiledir()"""
return tempfile.mkdtemp(dir=basedir)
def last_access_time(path): def last_access_time(path):
"""Return the number of seconds since the epoch of the last access of a given file""" """Return the number of seconds since the epoch of the last access of a given file"""
return os.stat(path)[stat.ST_ATIME] return os.stat(path)[stat.ST_ATIME]
...@@ -370,7 +375,7 @@ class ModuleCache(object): ...@@ -370,7 +375,7 @@ class ModuleCache(object):
rval = self.module_from_name[name] rval = self.module_from_name[name]
else: else:
# we have never seen this key before # we have never seen this key before
location = tempfile.mkdtemp(dir=self.dirname) location = dlimport_workdir(self.dirname)
#debug("LOCATION*", location) #debug("LOCATION*", location)
try: try:
module = fn(location=location) # WILL FAIL FOR BAD C CODE module = fn(location=location) # WILL FAIL FOR BAD C CODE
...@@ -627,7 +632,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -627,7 +632,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
libs = ['python2.6', 'cudart'] + libs libs = ['python2.6', 'cudart'] + libs
lib_dirs = ['/usr/local/cuda/lib']+lib_dirs lib_dirs = ['/usr/local/cuda/lib']+lib_dirs
workdir = tempfile.mkdtemp(dir=location) workdir = dlimport_workdir(location)
cppfilename = os.path.join(workdir, 'mod.cpp') #.cpp to use g++ cppfilename = os.path.join(workdir, 'mod.cpp') #.cpp to use g++
cppfilename = os.path.join(workdir, 'mod.cu') #.cu to use nvopencc cppfilename = os.path.join(workdir, 'mod.cu') #.cu to use nvopencc
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论