提交 79fa1216 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

Trunk sin last release Modification in the trunk since the last release
------ ------------------------------------------------
* Sparse type is now supported by the shape op and the ShapeFeature optimizer work correctly with them. * Sparse type is now supported by the shape op and the ShapeFeature optimizer work correctly with them.
* Fuse GpuElemwise more often (in the case where there are so many inputs that fusing them all would bust the 256 bytes limit of parameter to gpu function). * Fuse GpuElemwise more often (in the case where there are so many inputs that fusing them all would bust the 256 bytes limit of parameter to gpu function).
* Speed up gemv by a work around scipy gemv slowness when the matrix is in C order (the default). * Speed up gemv by a work around scipy gemv slowness when the matrix is in C order (the default).
......
import atexit, os, stat import atexit, os, stat
from theano.compile import optdb from theano.compile import optdb
from theano import config from theano import config
from theano.gof.cmodule import get_lib_extension
import logging import logging
_logger_name = 'theano.sandbox.cuda' _logger_name = 'theano.sandbox.cuda'
...@@ -54,8 +55,9 @@ cuda_files = ('cuda_ndarray.cu', 'cuda_ndarray.cuh', 'conv_full_kernel.cu', 'con ...@@ -54,8 +55,9 @@ cuda_files = ('cuda_ndarray.cu', 'cuda_ndarray.cuh', 'conv_full_kernel.cu', 'con
stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME] for cuda_file in cuda_files] stat_times = [os.stat(os.path.join(cuda_path, cuda_file))[stat.ST_MTIME] for cuda_file in cuda_files]
date = max(stat_times) date = max(stat_times)
cuda_ndarray_loc = os.path.join(config.compiledir,'cuda_ndarray') cuda_ndarray_loc = os.path.join(config.compiledir, 'cuda_ndarray')
cuda_ndarray_so = os.path.join(cuda_ndarray_loc,'cuda_ndarray.so') cuda_ndarray_so = os.path.join(cuda_ndarray_loc,
'cuda_ndarray.' + get_lib_extension())
compile_cuda_ndarray = True compile_cuda_ndarray = True
if os.path.exists(cuda_ndarray_so): if os.path.exists(cuda_ndarray_so):
...@@ -80,8 +82,11 @@ try: ...@@ -80,8 +82,11 @@ try:
if not os.path.exists(cuda_ndarray_loc): if not os.path.exists(cuda_ndarray_loc):
os.makedirs(cuda_ndarray_loc) os.makedirs(cuda_ndarray_loc)
nvcc_compiler.nvcc_module_compile_str('cuda_ndarray', code, location = cuda_ndarray_loc, nvcc_compiler.nvcc_module_compile_str(
include_dirs=[cuda_path], libs=['cublas']) 'cuda_ndarray',
code,
location=cuda_ndarray_loc,
include_dirs=[cuda_path], libs=['cublas'])
from cuda_ndarray.cuda_ndarray import * from cuda_ndarray.cuda_ndarray import *
except Exception, e: except Exception, e:
...@@ -109,9 +114,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType ...@@ -109,9 +114,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType
if cuda_available: if cuda_available:
#check if their is an old cuda_ndarray that was loading instead of the one we compiled! #check if their is an old cuda_ndarray that was loading instead of the one we compiled!
import cuda_ndarray.cuda_ndarray import cuda_ndarray.cuda_ndarray
from theano.gof.cmodule import get_lib_extension if cuda_ndarray_so != cuda_ndarray.cuda_ndarray.__file__:
if os.path.join(config.compiledir,'cuda_ndarray','cuda_ndarray.'+get_lib_extension())!=cuda_ndarray.cuda_ndarray.__file__:
warning("WARNING: cuda_ndarray was loaded from",cuda_ndarray.cuda_ndarray.__file__,"This is not expected as theano should compile it automatically for you. Do you have a directory called cuda_ndarray in your LD_LIBRARY_PATH environment variable? If so, please remove it as it is outdated!") warning("WARNING: cuda_ndarray was loaded from",cuda_ndarray.cuda_ndarray.__file__,"This is not expected as theano should compile it automatically for you. Do you have a directory called cuda_ndarray in your LD_LIBRARY_PATH environment variable? If so, please remove it as it is outdated!")
shared_constructor = float32_shared_constructor shared_constructor = float32_shared_constructor
......
...@@ -53,8 +53,9 @@ def is_nvcc_available(): ...@@ -53,8 +53,9 @@ def is_nvcc_available():
else: return False else: return False
is_nvcc_available()#to set nvcc_path correctly and get the version is_nvcc_available()#to set nvcc_path correctly and get the version
def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def nvcc_module_compile_str(
preargs=[]): module_name, src_code,
location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[]):
""" """
:param module_name: string (this has been embedded in the src_code :param module_name: string (this has been embedded in the src_code
:param src_code: a complete c or c++ source listing for the module :param src_code: a complete c or c++ source listing for the module
...@@ -71,8 +72,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -71,8 +72,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
""" """
if sys.platform=="win32": if sys.platform=="win32":
#remove some compilation args that cl.exe don't understand # Remove some compilation args that cl.exe does not understand.
#cl.exe is the compiler used by nvcc on Windows # cl.exe is the compiler used by nvcc on Windows.
for a in ["-Wno-write-strings","-Wno-unused-label", for a in ["-Wno-write-strings","-Wno-unused-label",
"-Wno-unused-variable", "-fno-math-errno"]: "-Wno-unused-variable", "-fno-math-errno"]:
if a in preargs: if a in preargs:
...@@ -185,8 +186,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -185,8 +186,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
orig_dir = os.getcwd() orig_dir = os.getcwd()
try: try:
os.chdir(location) os.chdir(location)
p = subprocess.Popen(
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
nvcc_stdout, nvcc_stderr = p.communicate()[:2] nvcc_stdout, nvcc_stderr = p.communicate()[:2]
finally: finally:
os.chdir(orig_dir) os.chdir(orig_dir)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论