提交 8422aa01 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed check of the existence of the cuda_ndarray module under Windows, where the…

Fixed check of the existence of the cuda_ndarray module under Windows, where the extension is not .so
上级 e8756c30
import atexit, os, stat import atexit, os, stat
from theano.compile import optdb from theano.compile import optdb
from theano import config from theano import config
from theano.gof.cmodule import get_lib_extension
import logging import logging
_logger_name = 'theano.sandbox.cuda' _logger_name = 'theano.sandbox.cuda'
...@@ -55,8 +56,9 @@ date = max(date,os.stat(os.path.join(cuda_path,'cuda_ndarray.cuh'))[stat.ST_MTIM ...@@ -55,8 +56,9 @@ date = max(date,os.stat(os.path.join(cuda_path,'cuda_ndarray.cuh'))[stat.ST_MTIM
date = max(date,os.stat(os.path.join(cuda_path,'conv_full_kernel.cu'))[stat.ST_MTIME]) date = max(date,os.stat(os.path.join(cuda_path,'conv_full_kernel.cu'))[stat.ST_MTIME])
date = max(date,os.stat(os.path.join(cuda_path,'conv_kernel.cu'))[stat.ST_MTIME]) date = max(date,os.stat(os.path.join(cuda_path,'conv_kernel.cu'))[stat.ST_MTIME])
cuda_ndarray_loc = os.path.join(config.compiledir,'cuda_ndarray') cuda_ndarray_loc = os.path.join(config.compiledir, 'cuda_ndarray')
cuda_ndarray_so = os.path.join(cuda_ndarray_loc,'cuda_ndarray.so') cuda_ndarray_so = os.path.join(cuda_ndarray_loc,
'cuda_ndarray.' + get_lib_extension())
compile_cuda_ndarray = True compile_cuda_ndarray = True
if os.path.exists(cuda_ndarray_so): if os.path.exists(cuda_ndarray_so):
...@@ -81,8 +83,11 @@ try: ...@@ -81,8 +83,11 @@ try:
if not os.path.exists(cuda_ndarray_loc): if not os.path.exists(cuda_ndarray_loc):
os.makedirs(cuda_ndarray_loc) os.makedirs(cuda_ndarray_loc)
nvcc_compiler.nvcc_module_compile_str('cuda_ndarray', code, location = cuda_ndarray_loc, nvcc_compiler.nvcc_module_compile_str(
include_dirs=[cuda_path], libs=['cublas']) 'cuda_ndarray',
code,
location=cuda_ndarray_loc,
include_dirs=[cuda_path], libs=['cublas'])
from cuda_ndarray.cuda_ndarray import * from cuda_ndarray.cuda_ndarray import *
except Exception, e: except Exception, e:
...@@ -103,9 +108,7 @@ if cuda_available: ...@@ -103,9 +108,7 @@ if cuda_available:
if cuda_available: if cuda_available:
#check if their is an old cuda_ndarray that was loading instead of the one we compiled! #check if their is an old cuda_ndarray that was loading instead of the one we compiled!
import cuda_ndarray.cuda_ndarray import cuda_ndarray.cuda_ndarray
from theano.gof.cmodule import get_lib_extension if cuda_ndarray_so != cuda_ndarray.cuda_ndarray.__file__:
if os.path.join(config.compiledir,'cuda_ndarray','cuda_ndarray.'+get_lib_extension())!=cuda_ndarray.cuda_ndarray.__file__:
warning("WARNING: cuda_ndarray was loaded from",cuda_ndarray.cuda_ndarray.__file__,"This is not expected as theano should compile it automatically for you. Do you have a directory called cuda_ndarray in your LD_LIBRARY_PATH environment variable? If so, please remove it as it is outdated!") warning("WARNING: cuda_ndarray was loaded from",cuda_ndarray.cuda_ndarray.__file__,"This is not expected as theano should compile it automatically for you. Do you have a directory called cuda_ndarray in your LD_LIBRARY_PATH environment variable? If so, please remove it as it is outdated!")
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论