提交 cff27c13 authored 作者: Frederic's avatar Frederic

make {nvcc,gcc}_module_compile_str a class with another function compile_args…

Make {nvcc,gcc}_module_compile_str a class with an additional function, compile_args, whose result gets added to the keys.
上级 2f2b424a
...@@ -622,6 +622,10 @@ class CLinker(link.Linker): ...@@ -622,6 +622,10 @@ class CLinker(link.Linker):
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: ret += x.c_compile_args() try: ret += x.c_compile_args()
except utils.MethodNotDefined: pass except utils.MethodNotDefined: pass
c_compiler = self.c_compiler()
ret += c_compiler.compile_args()
ret=list(set(ret))#to remove duplicate ret=list(set(ret))#to remove duplicate
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: try:
...@@ -661,7 +665,7 @@ class CLinker(link.Linker): ...@@ -661,7 +665,7 @@ class CLinker(link.Linker):
raise Exception('Nodes have requested specific different compilers', raise Exception('Nodes have requested specific different compilers',
(c_compiler, x_compiler)) (c_compiler, x_compiler))
if (c_compiler is None): if (c_compiler is None):
return cmodule.gcc_module_compile_str return cmodule.GCC_compiler
else: return c_compiler else: return c_compiler
def header_dirs(self): def header_dirs(self):
...@@ -1007,7 +1011,7 @@ class CLinker(link.Linker): ...@@ -1007,7 +1011,7 @@ class CLinker(link.Linker):
libs = self.libraries() libs = self.libraries()
preargs = self.compile_args() preargs = self.compile_args()
compiler_name = c_compiler.__name__ compiler_name = c_compiler.__name__
if compiler_name == 'nvcc_module_compile_str' and config.lib.amdlibm: if compiler_name == 'NVCC_compiler' and config.lib.amdlibm:
# This lib does not work correctly with nvcc in device code. # This lib does not work correctly with nvcc in device code.
# and newer version of g++ as 4.5.1. # and newer version of g++ as 4.5.1.
# example of errors: "/usr/lib/gcc/x86_64-redhat-linux/4.5.1/include/mmintrin.h(49): error: identifier "__builtin_ia32_emms" is undefined" # example of errors: "/usr/lib/gcc/x86_64-redhat-linux/4.5.1/include/mmintrin.h(49): error: identifier "__builtin_ia32_emms" is undefined"
...@@ -1024,7 +1028,7 @@ class CLinker(link.Linker): ...@@ -1024,7 +1028,7 @@ class CLinker(link.Linker):
try: try:
_logger.debug("LOCATION %s", str(location)) _logger.debug("LOCATION %s", str(location))
try: try:
module = c_compiler( module = c_compiler.compile_str(
module_name=mod.name, module_name=mod.name,
src_code=src_code, src_code=src_code,
location=location, location=location,
......
...@@ -1312,23 +1312,30 @@ def gcc_version(): ...@@ -1312,23 +1312,30 @@ def gcc_version():
return gcc_version_str return gcc_version_str
def gcc_module_compile_str(module_name, src_code, location=None, class GCC_compiler():
@staticmethod
def compile_args():
return []
@staticmethod
def compile_str(module_name, src_code, location=None,
include_dirs=[], lib_dirs=[], libs=[], preargs=[]): include_dirs=[], lib_dirs=[], libs=[], preargs=[]):
""" """
:param module_name: string (this has been embedded in the src_code :param module_name: string (this has been embedded in the src_code
:param src_code: a complete c or c++ source listing for the module :param src_code: a complete c or c++ source listing for the module
:param location: a pre-existing filesystem directory where the cpp file and :param location: a pre-existing filesystem directory where the
.so will be written cpp file and .so will be written
:param include_dirs: a list of include directory names (each gets prefixed :param include_dirs: a list of include directory names (each
with -I) gets prefixed with -I)
:param lib_dirs: a list of library search path directory names (each gets :param lib_dirs: a list of library search path directory names
prefixed with -L) (each gets prefixed with -L)
:param libs: a list of libraries to link with (each gets prefixed with -l) :param libs: a list of libraries to link with (each gets
prefixed with -l)
:param preargs: a list of extra compiler arguments :param preargs: a list of extra compiler arguments
...@@ -1362,8 +1369,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, ...@@ -1362,8 +1369,8 @@ def gcc_module_compile_str(module_name, src_code, location=None,
config.cmodule.mac_framework_link): config.cmodule.mac_framework_link):
preargs.extend(['-framework', 'Python']) preargs.extend(['-framework', 'Python'])
# Figure out whether the current Python executable is 32 or 64 bit and # Figure out whether the current Python executable is 32
# compile accordingly. # or 64 bit and compile accordingly.
n_bits = local_bitwidth() n_bits = local_bitwidth()
preargs.extend(['-m%s' % n_bits]) preargs.extend(['-m%s' % n_bits])
_logger.debug("OS X: compiling for %s bit architecture", n_bits) _logger.debug("OS X: compiling for %s bit architecture", n_bits)
......
...@@ -70,7 +70,7 @@ except ImportError: ...@@ -70,7 +70,7 @@ except ImportError:
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
cmodule.gcc_module_compile_str('cutils_ext', code, location=loc) cmodule.GCC_compiler.compile_str('cutils_ext', code, location=loc)
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import *
finally: finally:
......
...@@ -53,7 +53,7 @@ except ImportError: ...@@ -53,7 +53,7 @@ except ImportError:
loc = os.path.join(config.compiledir, dirname) loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
cmodule.gcc_module_compile_str(dirname, code, location=loc) cmodule.GCC_compiler.compile_str(dirname, code, location=loc)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, '__init__.py') init_py = os.path.join(loc, '__init__.py')
open(init_py, 'w').write('_version = %s\n' % version) open(init_py, 'w').write('_version = %s\n' % version)
......
...@@ -87,7 +87,7 @@ libcuda_ndarray_so = os.path.join(cuda_ndarray_loc, ...@@ -87,7 +87,7 @@ libcuda_ndarray_so = os.path.join(cuda_ndarray_loc,
# Add the theano cache directory's cuda_ndarray subdirectory to the # Add the theano cache directory's cuda_ndarray subdirectory to the
# list of places that are hard-coded into compiled modules' runtime # list of places that are hard-coded into compiled modules' runtime
# library search list. This works in conjunction with # library search list. This works in conjunction with
# nvcc_compiler.nvcc_module_compile_str which adds this folder during # nvcc_compiler.NVCC_compiler.compile_str which adds this folder during
# compilation with -L and also adds -lcuda_ndarray when compiling # compilation with -L and also adds -lcuda_ndarray when compiling
# modules. # modules.
nvcc_compiler.add_standard_rpath(cuda_ndarray_loc) nvcc_compiler.add_standard_rpath(cuda_ndarray_loc)
...@@ -117,7 +117,8 @@ try: ...@@ -117,7 +117,8 @@ try:
if not os.path.exists(cuda_ndarray_loc): if not os.path.exists(cuda_ndarray_loc):
os.makedirs(cuda_ndarray_loc) os.makedirs(cuda_ndarray_loc)
nvcc_compiler.nvcc_module_compile_str( compiler = nvcc_compiler.NVCC_compiler()
compiler.compile_str(
'cuda_ndarray', 'cuda_ndarray',
code, code,
location=cuda_ndarray_loc, location=cuda_ndarray_loc,
...@@ -130,7 +131,7 @@ except Exception, e: ...@@ -130,7 +131,7 @@ except Exception, e:
if cuda_available: if cuda_available:
# If necessary, # If necessary,
# create a symlink called libcuda_ndarray.so # create a symlink called libcuda_ndarray.so
# which nvcc_module_compile_str uses when linking # which nvcc_compiler.NVCC_compiler uses when linking
# any module except "cuda_ndarray" itself. # any module except "cuda_ndarray" itself.
try: try:
open(libcuda_ndarray_so).close() open(libcuda_ndarray_so).close()
......
...@@ -72,7 +72,23 @@ rpath_defaults = [] ...@@ -72,7 +72,23 @@ rpath_defaults = []
def add_standard_rpath(rpath): def add_standard_rpath(rpath):
rpath_defaults.append(rpath) rpath_defaults.append(rpath)
def nvcc_module_compile_str(
class NVCC_compiler():
@staticmethod
def compile_args():
"""
These args will be received by compile_str() in the preargs parameter.
They will also be included in the "hard" part of the key module.
"""
return []
# flags = [flag for flag in config.nvcc.flags.split(' ') if flag]
# cuda_ndarray_cuh_hash = hash_from_file(
# os.path.join(os.path.split(__file__)[0], 'cuda_ndarray.cuh'))
# cuda_macro = '-DCUDA_NDARRAY_CUH=' + cuda_ndarray_cuh_hash
# return [cuda_macro]
@staticmethod
def compile_str(
module_name, src_code, module_name, src_code,
location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[], location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[],
rpaths=rpath_defaults): rpaths=rpath_defaults):
......
...@@ -12,7 +12,7 @@ try: ...@@ -12,7 +12,7 @@ try:
# We must do those import to be able to create the full doc when nvcc # We must do those import to be able to create the full doc when nvcc
# is not available # is not available
import cuda_ndarray.cuda_ndarray as cuda import cuda_ndarray.cuda_ndarray as cuda
from theano.sandbox.cuda.nvcc_compiler import nvcc_module_compile_str from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
import cuda_ndarray import cuda_ndarray
except ImportError: except ImportError:
pass pass
...@@ -370,7 +370,7 @@ class CudaNdarrayType(Type): ...@@ -370,7 +370,7 @@ class CudaNdarrayType(Type):
return (2,) # with assertion about refcounts return (2,) # with assertion about refcounts
def c_compiler(self): def c_compiler(self):
return nvcc_module_compile_str return NVCC_compiler
def c_compile_args(self): def c_compile_args(self):
ret = [] ret = []
......
...@@ -50,8 +50,8 @@ except ImportError: ...@@ -50,8 +50,8 @@ except ImportError:
loc = os.path.join(config.compiledir, dirname) loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
cmodule.gcc_module_compile_str(dirname, code, location=loc, cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
preargs = ['-pthread','-fwrapv', preargs=['-pthread', '-fwrapv',
'-O2', '-O2',
'-fno-strict-aliasing']) '-fno-strict-aliasing'])
# Save version into the __init__.py file. # Save version into the __init__.py file.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论