提交 2f1c91cb authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1422 from HapeMask/py3k-fixes

Fix CURAND Wrapper in Python 3
...@@ -3,25 +3,21 @@ ...@@ -3,25 +3,21 @@
// Defines for Python 2/3 compatibility. // Defines for Python 2/3 compatibility.
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
// Py3k treats all ints as longs. // Py3k treats all ints as longs. This one is not caught by npy_3kcompat.h.
#define PyInt_Check PyLong_Check
#define PyInt_CheckExact PyLong_CheckExact
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyNumber_Int PyNumber_Long #define PyNumber_Int PyNumber_Long
#include "numpy/npy_3kcompat.h"
// Py3k strings are unicode, these mimic old functionality. // Py3k strings are unicode, these mimic old functionality.
//
// NOTE: npy_3kcompat.h replaces PyString_X with PyBytes_X, which breaks
// compatibility with some functions returning text.
#define PyString_Check PyUnicode_Check #define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString #define PyString_FromString PyUnicode_FromString
#define PyString_AsString PyUnicode_AsUTF8 #define PyString_AsString PyUnicode_AsUTF8
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_Size PyUnicode_GET_SIZE #define PyString_Size PyUnicode_GET_SIZE
#include "numpy/npy_3kcompat.h"
#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr
#define PyCObject_GetDesc NpyCapsule_GetDesc
#define PyCObject_Check NpyCapsule_Check
// Python 3 expects a PyObject* as the first argument to PySlice_GetIndicesEx(). // Python 3 expects a PyObject* as the first argument to PySlice_GetIndicesEx().
#define SLICE_CAST(x) (x) #define SLICE_CAST(x) (x)
#else #else
......
...@@ -107,16 +107,23 @@ class CURAND_Base(GpuOp): ...@@ -107,16 +107,23 @@ class CURAND_Base(GpuOp):
def c_support_code(self): def c_support_code(self):
return """ return """
#if PY_MAJOR_VERSION >= 3
void free_generator(PyObject *_gen)
{
curandGenerator_t * gen = (curandGenerator_t*)NpyCapsule_AsVoidPtr(_gen);
#else
void free_generator(void *_gen) void free_generator(void *_gen)
{ {
curandGenerator_t * gen = (curandGenerator_t*)_gen; curandGenerator_t * gen = (curandGenerator_t*)_gen;
#endif
curandStatus_t err = curandDestroyGenerator(*gen); curandStatus_t err = curandDestroyGenerator(*gen);
if (err != CURAND_STATUS_SUCCESS) if (err != CURAND_STATUS_SUCCESS)
{ {
fprintf(stderr, "Failure (%%i) in destroying CURAND generator", fprintf(stderr, "Failure (%i) in destroying CURAND generator.\\n",
(int)err); (int)err);
} }
free(_gen); free(gen);
} }
""" """
...@@ -136,10 +143,6 @@ class CURAND_Base(GpuOp): ...@@ -136,10 +143,6 @@ class CURAND_Base(GpuOp):
code = """ code = """
//////// <code generated by CURAND_Base> //////// <code generated by CURAND_Base>
#if PY_MAJOR_VERSION >= 3
#include "numpy/npy_3kcompat.h"
#endif
int odims[%(ndim)s]; int odims[%(ndim)s];
int n_elements = 1; int n_elements = 1;
int must_alloc_sample = ((NULL == %(o_sample)s) int must_alloc_sample = ((NULL == %(o_sample)s)
......
...@@ -2479,7 +2479,7 @@ class Shape(Op): ...@@ -2479,7 +2479,7 @@ class Shape(Op):
#TODO: if your type is not listed here, make a damn registry of #TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables. # shape_i ops for various types of variables.
# Do not continue this madness. # Do not continue this madness.
return super(Shape_i, self).c_code(node, name, (x,), (out,), sub) return super(Shape, self).c_code(node, nodename, (x,), (out,), sub)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论