提交 580367c1 authored 作者: James Bergstra's avatar James Bergstra

added rpaths parameter to nvcc_module_compile_str

上级 dfa86140
......@@ -71,9 +71,14 @@ if config.cuda.root == "AUTO":
is_nvcc_available()#to set nvcc_path correctly and get the version
rpath_defaults = []
def add_standard_rpath(rpath):
rpath_defaults.append(rpath)
def nvcc_module_compile_str(
module_name, src_code,
location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[]):
location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[],
rpaths=rpath_defaults):
"""
:param module_name: string (this has been embedded in the src_code
:param src_code: a complete c or c++ source listing for the module
......@@ -82,6 +87,7 @@ def nvcc_module_compile_str(
:param lib_dirs: a list of library search path directory names (each gets prefixed with -L)
:param libs: a list of libraries to link with (each gets prefixed with -l)
:param preargs: a list of extra compiler arguments
:param rpaths: list of rpaths to use with Xlinker. Defaults to `rpath_defaults`.
:returns: dynamically-imported python module of the compiled code.
......@@ -89,6 +95,8 @@ def nvcc_module_compile_str(
Otherwise nvcc never finish.
"""
rpaths = list(rpaths)
if sys.platform=="win32":
# Remove some compilation args that cl.exe does not understand.
# cl.exe is the compiler used by nvcc on Windows.
......@@ -171,10 +179,13 @@ def nvcc_module_compile_str(
cmd.extend(['-Xcompiler', ','.join(preargs2)])
if config.cuda.root and os.path.exists(os.path.join(config.cuda.root,'lib')):
cmd.extend(['-Xlinker',','.join(['-rpath',os.path.join(config.cuda.root,'lib')])])
rpaths.append(os.path.join(config.cuda.root,'lib'))
if sys.platform != 'darwin':
# the 64bit CUDA libs are in the same files as are named by the function above
cmd.extend(['-Xlinker',','.join(['-rpath',os.path.join(config.cuda.root,'lib64')])])
rpaths.append(os.path.join(config.cuda.root,'lib64'))
for rpath in rpaths:
cmd.extend(['-Xlinker',','.join(['-rpath',rpath])])
nvccflags = [flag for flag in config.cuda.nvccflags.split(' ') if flag]
cmd.extend(nvccflags)
cmd.extend('-I%s'%idir for idir in include_dirs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论