提交 eedf7d47 authored 作者: Frederic Bastien's avatar Frederic Bastien

recompile if the source file and dependency of cuda_ndarray are newer then compiled file.

上级 a68097aa
import os, sys import os, sys, stat
from theano.gof.compiledir import get_compiledir from theano.gof.compiledir import get_compiledir
from theano.compile import optdb from theano.compile import optdb
import theano.config as config import theano.config as config
...@@ -47,22 +47,37 @@ def set_cuda_disabled(): ...@@ -47,22 +47,37 @@ def set_cuda_disabled():
#cuda_ndarray compile and import #cuda_ndarray compile and import
sys.path.append(get_compiledir()) sys.path.append(get_compiledir())
try: cuda_path = os.path.split(__file__)[0]
date = os.stat(os.path.join(cuda_path,'cuda_ndarray.cu'))[stat.ST_MTIME]
date = max(date,os.stat(os.path.join(cuda_path,'cuda_ndarray.cuh'))[stat.ST_MTIME])
date = max(date,os.stat(os.path.join(cuda_path,'conv_full_kernel.cu'))[stat.ST_MTIME])
date = max(date,os.stat(os.path.join(cuda_path,'conv_kernel.cu'))[stat.ST_MTIME])
cuda_ndarray_loc = os.path.join(get_compiledir(),'cuda_ndarray')
cuda_ndarray_so = os.path.join(cuda_ndarray_loc,'cuda_ndarray.so')
compile_cuda_ndarray = True
if os.path.exists(cuda_ndarray_so):
compile_cuda_ndarray = date>=os.stat(cuda_ndarray_so)[stat.ST_MTIME]
if not compile_cuda_ndarray:
try:
from cuda_ndarray.cuda_ndarray import * from cuda_ndarray.cuda_ndarray import *
except ImportError: except ImportError:
compile_cuda_ndarray = True
if compile_cuda_ndarray:
import nvcc_compiler import nvcc_compiler
if not nvcc_compiler.is_nvcc_available(): if not nvcc_compiler.is_nvcc_available():
set_cuda_disabled() set_cuda_disabled()
if enable_cuda: if enable_cuda:
cuda_path = os.path.split(__file__)[0]
code = open(os.path.join(cuda_path, "cuda_ndarray.cu")).read() code = open(os.path.join(cuda_path, "cuda_ndarray.cu")).read()
loc = os.path.join(get_compiledir(),'cuda_ndarray') if not os.path.exists(cuda_ndarray_loc):
if not os.path.exists(loc): os.makedirs(cuda_ndarray_loc)
os.makedirs(loc)
nvcc_compiler.nvcc_module_compile_str('cuda_ndarray', code, location = loc, include_dirs=[cuda_path], libs=['cublas'], nvcc_compiler.nvcc_module_compile_str('cuda_ndarray', code, location = cuda_ndarray_loc,
include_dirs=[cuda_path], libs=['cublas'],
preargs=['-DDONT_UNROLL', '-O3']) preargs=['-DDONT_UNROLL', '-O3'])
from cuda_ndarray.cuda_ndarray import * from cuda_ndarray.cuda_ndarray import *
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论