提交 ee5534cc authored 作者: Frederic's avatar Frederic

don't recompile cuda_ndarray.so if it was outdated, but someone else compiled it.

上级 01e0840c
...@@ -69,14 +69,6 @@ def set_cuda_disabled(): ...@@ -69,14 +69,6 @@ def set_cuda_disabled():
#cuda_ndarray compile and import #cuda_ndarray compile and import
cuda_path = os.path.abspath(os.path.split(__file__)[0]) cuda_path = os.path.abspath(os.path.split(__file__)[0])
cuda_files = (
'cuda_ndarray.cu',
'cuda_ndarray.cuh',
'conv_full_kernel.cu',
'conv_kernel.cu')
stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME]
for cuda_file in cuda_files]
date = max(stat_times)
cuda_ndarray_loc = os.path.join(config.compiledir, 'cuda_ndarray') cuda_ndarray_loc = os.path.join(config.compiledir, 'cuda_ndarray')
cuda_ndarray_so = os.path.join(cuda_ndarray_loc, cuda_ndarray_so = os.path.join(cuda_ndarray_loc,
...@@ -85,6 +77,33 @@ libcuda_ndarray_so = os.path.join(cuda_ndarray_loc, ...@@ -85,6 +77,33 @@ libcuda_ndarray_so = os.path.join(cuda_ndarray_loc,
'libcuda_ndarray.' + get_lib_extension()) 'libcuda_ndarray.' + get_lib_extension())
def try_import():
"""
load the cuda_ndarray module if present and up to date
return True if loaded correctly, otherwise return False
"""
cuda_files = (
'cuda_ndarray.cu',
'cuda_ndarray.cuh',
'conv_full_kernel.cu',
'conv_kernel.cu')
stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME]
for cuda_file in cuda_files]
date = max(stat_times)
if os.path.exists(cuda_ndarray_so):
if date >= os.stat(cuda_ndarray_so)[stat.ST_MTIME]:
return False
try:
# If we load a previously-compiled version, config.compiledir should
# be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
except ImportError:
return False
return True
# Add the theano cache directory's cuda_ndarray subdirectory to the # Add the theano cache directory's cuda_ndarray subdirectory to the
# list of places that are hard-coded into compiled modules' runtime # list of places that are hard-coded into compiled modules' runtime
# library search list. This works in conjunction with # library search list. This works in conjunction with
...@@ -95,41 +114,23 @@ nvcc_compiler.add_standard_rpath(cuda_ndarray_loc) ...@@ -95,41 +114,23 @@ nvcc_compiler.add_standard_rpath(cuda_ndarray_loc)
compile_cuda_ndarray = True compile_cuda_ndarray = True
outdated = False
if os.path.exists(cuda_ndarray_so):
outdated = date >= os.stat(cuda_ndarray_so)[stat.ST_MTIME]
compile_cuda_ndarray = outdated
if not compile_cuda_ndarray: if not compile_cuda_ndarray:
try: compile_cuda_ndarray = not try_import()
# If we load a previously-compiled version, config.compiledir should
# be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
except ImportError:
compile_cuda_ndarray = True
if compile_cuda_ndarray: if compile_cuda_ndarray:
get_lock() get_lock()
try: try:
if not outdated: # Retry to load again in case someone else compiled it
# Maybe someone else compiled it while we got the lock. # while we waited for the lock
try: compile_cuda_ndarray = not try_import()
# If we load a previously-compiled version, if not try_import():
# config.compiledir should be in sys.path.
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
from cuda_ndarray.cuda_ndarray import *
compile_cuda_ndarray = False
except ImportError:
compile_cuda_ndarray = True
if compile_cuda_ndarray:
try: try:
if not nvcc_compiler.is_nvcc_available(): if not nvcc_compiler.is_nvcc_available():
set_cuda_disabled() set_cuda_disabled()
if cuda_available: if cuda_available:
code = open(os.path.join(cuda_path, "cuda_ndarray.cu")).read() code = open(os.path.join(cuda_path,
"cuda_ndarray.cu")).read()
if not os.path.exists(cuda_ndarray_loc): if not os.path.exists(cuda_ndarray_loc):
os.makedirs(cuda_ndarray_loc) os.makedirs(cuda_ndarray_loc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论