提交 13887e2c 作者: Frederic

Fix gh-747: pass the GPU architecture to nvcc when compiling.

上级 fa2904f2
...@@ -7,6 +7,7 @@ import subprocess ...@@ -7,6 +7,7 @@ import subprocess
import sys import sys
import warnings import warnings
import theano
from theano.gof.cc import hash_from_file from theano.gof.cc import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs, from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport, std_include_dirs, dlimport,
...@@ -119,6 +120,16 @@ class NVCC_compiler(object): ...@@ -119,6 +120,16 @@ class NVCC_compiler(object):
cuda_ndarray_cuh_hash = hash_from_file( cuda_ndarray_cuh_hash = hash_from_file(
os.path.join(os.path.split(__file__)[0], 'cuda_ndarray.cuh')) os.path.join(os.path.split(__file__)[0], 'cuda_ndarray.cuh'))
flags.append('-DCUDA_NDARRAY_CUH=' + cuda_ndarray_cuh_hash) flags.append('-DCUDA_NDARRAY_CUH=' + cuda_ndarray_cuh_hash)
# We compile cuda_ndarray.cu during import.
# We should not add device properties at that time,
# because the device is not selected yet.
# TODO: compile cuda_ndarray when we bind to a GPU?
import theano.sandbox.cuda
if hasattr(theano.sandbox, 'cuda'):
n = theano.sandbox.cuda.use.device_number
p = theano.sandbox.cuda.device_properties(n)
flags.append('-arch=sm_' + str(p['major']) + str(p['minor']))
return flags return flags
@staticmethod @staticmethod
...@@ -217,7 +228,9 @@ class NVCC_compiler(object): ...@@ -217,7 +228,9 @@ class NVCC_compiler(object):
# '--gpu-code=compute_13', # '--gpu-code=compute_13',
#nvcc argument #nvcc argument
preargs1 = [pa for pa in preargs preargs1 = [pa for pa in preargs
if pa.startswith('-O') or pa.startswith('--maxrregcount=')] if pa.startswith('-O') or
pa.startswith('--maxrregcount=') or
pa.startswith('-arch=')]
preargs2 = [pa for pa in preargs preargs2 = [pa for pa in preargs
if pa not in preargs1] # other arguments if pa not in preargs1] # other arguments
...@@ -337,6 +350,7 @@ class NVCC_compiler(object): ...@@ -337,6 +350,7 @@ class NVCC_compiler(object):
pass pass
print >> sys.stderr, l print >> sys.stderr, l
print nvcc_stdout print nvcc_stdout
print cmd
raise Exception('nvcc return status', p.returncode, raise Exception('nvcc return status', p.returncode,
'for cmd', ' '.join(cmd)) 'for cmd', ' '.join(cmd))
elif config.cmodule.compilation_warning and nvcc_stdout: elif config.cmodule.compilation_warning and nvcc_stdout:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论