提交 ee5534cc authored 作者: Frederic's avatar Frederic

don't recompile cuda_ndarray.so if it was outdated, but someone else compiled it.

上级 01e0840c
......@@ -69,14 +69,6 @@ def set_cuda_disabled():
#cuda_ndarray compile and import
cuda_path = os.path.abspath(os.path.split(__file__)[0])
cuda_files = (
'cuda_ndarray.cu',
'cuda_ndarray.cuh',
'conv_full_kernel.cu',
'conv_kernel.cu')
stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME]
for cuda_file in cuda_files]
date = max(stat_times)
cuda_ndarray_loc = os.path.join(config.compiledir, 'cuda_ndarray')
cuda_ndarray_so = os.path.join(cuda_ndarray_loc,
......@@ -85,6 +77,33 @@ libcuda_ndarray_so = os.path.join(cuda_ndarray_loc,
'libcuda_ndarray.' + get_lib_extension())
def try_import():
"""
load the cuda_ndarray module if present and up to date
return True if loaded correctly, otherwise return False
"""
cuda_files = (
'cuda_ndarray.cu',
'cuda_ndarray.cuh',
'conv_full_kernel.cu',
'conv_kernel.cu')
stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME]
for cuda_file in cuda_files]
date = max(stat_times)
if os.path.exists(cuda_ndarray_so):
if date >= os.stat(cuda_ndarray_so)[stat.ST_MTIME]:
return False
try:
# If we load a previously-compiled version, config.compiledir should
# be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
except ImportError:
return False
return True
# Add the theano cache directory's cuda_ndarray subdirectory to the
# list of places that are hard-coded into compiled modules' runtime
# library search list. This works in conjunction with
......@@ -95,41 +114,23 @@ nvcc_compiler.add_standard_rpath(cuda_ndarray_loc)
compile_cuda_ndarray = True
outdated = False
if os.path.exists(cuda_ndarray_so):
outdated = date >= os.stat(cuda_ndarray_so)[stat.ST_MTIME]
compile_cuda_ndarray = outdated
if not compile_cuda_ndarray:
try:
# If we load a previously-compiled version, config.compiledir should
# be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
except ImportError:
compile_cuda_ndarray = True
compile_cuda_ndarray = not try_import()
if compile_cuda_ndarray:
get_lock()
try:
if not outdated:
# Maybe someone else compiled it while we got the lock.
try:
# If we load a previously-compiled version,
# config.compiledir should be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
compile_cuda_ndarray = False
except ImportError:
compile_cuda_ndarray = True
if compile_cuda_ndarray:
# Retry to load again in case someone else compiled it
# while we waited for the lock
compile_cuda_ndarray = not try_import()
if not try_import():
try:
if not nvcc_compiler.is_nvcc_available():
set_cuda_disabled()
if cuda_available:
code = open(os.path.join(cuda_path, "cuda_ndarray.cu")).read()
code = open(os.path.join(cuda_path,
"cuda_ndarray.cu")).read()
if not os.path.exists(cuda_ndarray_loc):
os.makedirs(cuda_ndarray_loc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论