提交 83dadc9d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2765 from abergeron/fix_mac_cuda

Fix cuda segfault on mac
...@@ -41,6 +41,14 @@ from theano.configdefaults import config ...@@ -41,6 +41,14 @@ from theano.configdefaults import config
# Version information. # Version information.
from theano.version import version as __version__ from theano.version import version as __version__
# This is the api version for ops that generate C code. External ops
# might need manual changes if this number goes up. An undefined
# __api_version__ can be understood to mean api version 0.
#
# This number is not tied to the release version and should change
# very rarely.
__api_version__ = 1
from theano.gof import ( from theano.gof import (
CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker,
Container, Container,
......
...@@ -149,7 +149,7 @@ class DynamicModule(object): ...@@ -149,7 +149,7 @@ class DynamicModule(object):
self.support_code = [] self.support_code = []
self.functions = [] self.functions = []
self.includes = ["<Python.h>", "<iostream>"] self.includes = ["<Python.h>", "<iostream>", '"theano_mod_helper.h"']
self.init_blocks = [] self.init_blocks = []
def print_methoddef(self, stream): def print_methoddef(self, stream):
...@@ -1447,7 +1447,8 @@ def std_include_dirs(): ...@@ -1447,7 +1447,8 @@ def std_include_dirs():
py_plat_spec_inc = distutils.sysconfig.get_python_inc(plat_specific=True) py_plat_spec_inc = distutils.sysconfig.get_python_inc(plat_specific=True)
python_inc_dirs = ([py_inc] if py_inc == py_plat_spec_inc python_inc_dirs = ([py_inc] if py_inc == py_plat_spec_inc
else [py_inc, py_plat_spec_inc]) else [py_inc, py_plat_spec_inc])
return numpy_inc_dirs + python_inc_dirs gof_inc_dir = os.path.dirname(__file__)
return numpy_inc_dirs + python_inc_dirs + [gof_inc_dir]
def std_lib_dirs_and_libs(): def std_lib_dirs_and_libs():
...@@ -1927,7 +1928,7 @@ class GCC_compiler(Compiler): ...@@ -1927,7 +1928,7 @@ class GCC_compiler(Compiler):
@staticmethod @staticmethod
def compile_str(module_name, src_code, location=None, def compile_str(module_name, src_code, location=None,
include_dirs=None, lib_dirs=None, libs=None, include_dirs=None, lib_dirs=None, libs=None,
preargs=None, py_module=True): preargs=None, py_module=True, hide_symbols=True):
""" """
:param module_name: string (this has been embedded in the src_code :param module_name: string (this has been embedded in the src_code
...@@ -1950,6 +1951,10 @@ class GCC_compiler(Compiler): ...@@ -1950,6 +1951,10 @@ class GCC_compiler(Compiler):
:param py_module: if False, compile to a shared library, but do not :param py_module: if False, compile to a shared library, but do not
import it as a Python module. import it as a Python module.
:param hide_symbols: if True (the default) all symbols will be
hidden from the library symbol table (which means that other
objects can't use them.
:returns: dynamically-imported python module of the compiled code. :returns: dynamically-imported python module of the compiled code.
(unless py_module is False, in that case returns None.) (unless py_module is False, in that case returns None.)
""" """
...@@ -2004,6 +2009,14 @@ class GCC_compiler(Compiler): ...@@ -2004,6 +2009,14 @@ class GCC_compiler(Compiler):
else: else:
cmd.extend(preargs) cmd.extend(preargs)
cmd.extend('-I%s' % idir for idir in include_dirs) cmd.extend('-I%s' % idir for idir in include_dirs)
if hide_symbols and sys.platform != 'win32':
# This has been available since gcc 4.0 so we suppose it
# is always available. We pass it here since it
# significantly reduces the size of the symbol table for
# the objects we want to share. This in turns leads to
# improved loading times on most platforms (win32 is
# different, as usual).
cmd.append('-fvisibility=hidden')
cmd.extend(['-o', lib_filename]) cmd.extend(['-o', lib_filename])
cmd.append(cppfilename) cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs]) cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
......
...@@ -198,6 +198,7 @@ def compile_cutils(): ...@@ -198,6 +198,7 @@ def compile_cutils():
code = (""" code = ("""
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
#include "theano_mod_helper.h"
extern "C"{ extern "C"{
static PyObject * static PyObject *
......
#include <Python.h> #include <Python.h>
#include "theano_mod_helper.h"
#include "structmember.h" #include "structmember.h"
#include <sys/time.h> #include <sys/time.h>
// Old Python compatibility from here:
// http://www.python.org/dev/peps/pep-0353/
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
// This one was taken from:
// http://svn.python.org/projects/python/trunk/Modules/_ctypes/ctypes.h
#define PyNumber_AsSsize_t(ob, exc) PyInt_AsLong(ob)
#endif
#if PY_VERSION_HEX >= 0x03000000 #if PY_VERSION_HEX >= 0x03000000
#include "numpy/npy_3kcompat.h" #include "numpy/npy_3kcompat.h"
#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr #define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr
...@@ -1033,10 +1023,6 @@ static PyMethodDef lazylinker_ext_methods[] = { ...@@ -1033,10 +1023,6 @@ static PyMethodDef lazylinker_ext_methods[] = {
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if defined(NPY_PY3K) #if defined(NPY_PY3K)
static struct PyModuleDef moduledef = { static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT, PyModuleDef_HEAD_INIT,
...@@ -1076,4 +1062,3 @@ initlazylinker_ext(void) ...@@ -1076,4 +1062,3 @@ initlazylinker_ext(void)
return RETVAL; return RETVAL;
} }
#ifndef THEANO_MOD_HELPER
#define THEANO_MOD_HELPER
#include <Python.h>
#ifndef _WIN32
#define MOD_PUBLIC __attribute__((visibility ("default")))
#else
#define MOD_PUBLIC
#endif
#ifdef __cplusplus
#define THEANO_EXTERN extern "C"
#else
#define THEANO_EXTERN
#endif
#if PY_MAJOR_VERSION < 3
#define THEANO_RTYPE void
#else
#define THEANO_RTYPE PyObject *
#endif
/* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */
#undef PyMODINIT_FUNC
#define PyMODINIT_FUNC THEANO_EXTERN MOD_PUBLIC THEANO_RTYPE
#endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <Python.h> #include <Python.h>
#include <structmember.h> #include <structmember.h>
#include "theano_mod_helper.h"
#include <numpy/arrayobject.h> #include <numpy/arrayobject.h>
#include <iostream> #include <iostream>
...@@ -3466,10 +3467,6 @@ static PyMethodDef module_methods[] = { ...@@ -3466,10 +3467,6 @@ static PyMethodDef module_methods[] = {
{NULL, NULL, NULL, NULL} /* Sentinel */ {NULL, NULL, NULL, NULL} /* Sentinel */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#define CNDA_MOD_NAME "cuda_ndarray" #define CNDA_MOD_NAME "cuda_ndarray"
#define CNDA_DOCSTRING "CUDA implementation of a numpy ndarray-like object." #define CNDA_DOCSTRING "CUDA implementation of a numpy ndarray-like object."
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <algorithm> #include <algorithm>
#include "theano_mod_helper.h"
// Defines for Python 2/3 compatibility. // Defines for Python 2/3 compatibility.
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
// Py3k treats all ints as longs. This one is not caught by npy_3kcompat.h. // Py3k treats all ints as longs. This one is not caught by npy_3kcompat.h.
...@@ -46,15 +48,15 @@ ...@@ -46,15 +48,15 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#ifdef _WIN32 #ifdef _WIN32
#ifdef _CUDA_NDARRAY_C # ifdef _CUDA_NDARRAY_C
#define DllExport __declspec( dllexport ) # define DllExport __declspec( dllexport )
#else # else
#define DllExport __declspec( dllimport ) # define DllExport __declspec( dllimport )
#endif # endif
#define ALWAYS_INLINE # define ALWAYS_INLINE
#else //else _WIN32 #else //else _WIN32
#define DllExport # define DllExport MOD_PUBLIC
#define ALWAYS_INLINE __attribute__((always_inline)) # define ALWAYS_INLINE __attribute__((always_inline))
#endif #endif
typedef float real; typedef float real;
......
...@@ -162,18 +162,18 @@ class NVCC_compiler(Compiler): ...@@ -162,18 +162,18 @@ class NVCC_compiler(Compiler):
# to use the new API, but not everywhere. When finished, enable # to use the new API, but not everywhere. When finished, enable
# the following macro to assert that we don't bring new code # the following macro to assert that we don't bring new code
# that use the old API. # that use the old API.
flags.append("-D NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION") flags.append("-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION")
# numpy 1.7 deprecated the following macro but the didn't # numpy 1.7 deprecated the following macro but the didn't
# existed in the past # existed in the past
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
if bool(numpy_ver < [1, 7]): if bool(numpy_ver < [1, 7]):
flags.append("-D NPY_ARRAY_ENSURECOPY=NPY_ENSURECOPY") flags.append("-DNPY_ARRAY_ENSURECOPY=NPY_ENSURECOPY")
flags.append("-D NPY_ARRAY_ALIGNED=NPY_ALIGNED") flags.append("-DNPY_ARRAY_ALIGNED=NPY_ALIGNED")
flags.append("-D NPY_ARRAY_WRITEABLE=NPY_WRITEABLE") flags.append("-DNPY_ARRAY_WRITEABLE=NPY_WRITEABLE")
flags.append("-D NPY_ARRAY_UPDATE_ALL=NPY_UPDATE_ALL") flags.append("-DNPY_ARRAY_UPDATE_ALL=NPY_UPDATE_ALL")
flags.append("-D NPY_ARRAY_C_CONTIGUOUS=NPY_C_CONTIGUOUS") flags.append("-DNPY_ARRAY_C_CONTIGUOUS=NPY_C_CONTIGUOUS")
flags.append("-D NPY_ARRAY_F_CONTIGUOUS=NPY_F_CONTIGUOUS") flags.append("-DNPY_ARRAY_F_CONTIGUOUS=NPY_F_CONTIGUOUS")
# If the user didn't specify architecture flags add them # If the user didn't specify architecture flags add them
if not any(['-arch=sm_' in f for f in flags]): if not any(['-arch=sm_' in f for f in flags]):
...@@ -208,7 +208,7 @@ class NVCC_compiler(Compiler): ...@@ -208,7 +208,7 @@ class NVCC_compiler(Compiler):
def compile_str( def compile_str(
module_name, src_code, module_name, src_code,
location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[], location=None, include_dirs=[], lib_dirs=[], libs=[], preargs=[],
rpaths=rpath_defaults, py_module=True): rpaths=rpath_defaults, py_module=True, hide_symbols=True):
""":param module_name: string (this has been embedded in the src_code """:param module_name: string (this has been embedded in the src_code
:param src_code: a complete c or c++ source listing for the module :param src_code: a complete c or c++ source listing for the module
:param location: a pre-existing filesystem directory where the :param location: a pre-existing filesystem directory where the
...@@ -225,6 +225,9 @@ class NVCC_compiler(Compiler): ...@@ -225,6 +225,9 @@ class NVCC_compiler(Compiler):
:param py_module: if False, compile to a shared library, but :param py_module: if False, compile to a shared library, but
do not import as a Python module. do not import as a Python module.
:param hide_symbols: if True (the default), hide all symbols
from the library symbol table unless explicitely exported.
:returns: dynamically-imported python module of the compiled code. :returns: dynamically-imported python module of the compiled code.
(unless py_module is False, in that case returns None.) (unless py_module is False, in that case returns None.)
...@@ -320,6 +323,9 @@ class NVCC_compiler(Compiler): ...@@ -320,6 +323,9 @@ class NVCC_compiler(Compiler):
# in both math_functions.h and pymath.h, # in both math_functions.h and pymath.h,
# by not including the one in pymath.h # by not including the one in pymath.h
cmd.extend(['-D HAVE_ROUND']) cmd.extend(['-D HAVE_ROUND'])
else:
if hide_symbols:
preargs2.append('-fvisibility=hidden')
if local_bitwidth() == 64: if local_bitwidth() == 64:
cmd.append('-m64') cmd.append('-m64')
......
...@@ -109,7 +109,8 @@ except ImportError: ...@@ -109,7 +109,8 @@ except ImportError:
preargs.append("-D NPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS") preargs.append("-D NPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")
cmodule.GCC_compiler.compile_str(dirname, code, location=loc, cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
preargs=preargs) preargs=preargs,
hide_symbols=False)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, '__init__.py') init_py = os.path.join(loc, '__init__.py')
open(init_py, 'w').write('_version = %s\n' % version) open(init_py, 'w').write('_version = %s\n' % version)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论