提交 2f1c91cb authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1422 from HapeMask/py3k-fixes

Fix CURAND Wrapper in Python 3
......@@ -3,25 +3,21 @@
// Defines for Python 2/3 compatibility.
#if PY_MAJOR_VERSION >= 3
// Py3k treats all ints as longs.
#define PyInt_Check PyLong_Check
#define PyInt_CheckExact PyLong_CheckExact
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
// Py3k treats all ints as longs. This one is not caught by npy_3kcompat.h.
#define PyNumber_Int PyNumber_Long
#include "numpy/npy_3kcompat.h"
// Py3k strings are unicode, these mimic old functionality.
//
// NOTE: npy_3kcompat.h replaces PyString_X with PyBytes_X, which breaks
// compatibility with some functions returning text.
#define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString
#define PyString_AsString PyUnicode_AsUTF8
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_Size PyUnicode_GET_SIZE
#include "numpy/npy_3kcompat.h"
#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr
#define PyCObject_GetDesc NpyCapsule_GetDesc
#define PyCObject_Check NpyCapsule_Check
// Python 3 expects a PyObject* as the first argument to PySlice_GetIndicesEx().
#define SLICE_CAST(x) (x)
#else
......
......@@ -107,16 +107,23 @@ class CURAND_Base(GpuOp):
def c_support_code(self):
return """
#if PY_MAJOR_VERSION >= 3
void free_generator(PyObject *_gen)
{
curandGenerator_t * gen = (curandGenerator_t*)NpyCapsule_AsVoidPtr(_gen);
#else
void free_generator(void *_gen)
{
curandGenerator_t * gen = (curandGenerator_t*)_gen;
#endif
curandStatus_t err = curandDestroyGenerator(*gen);
if (err != CURAND_STATUS_SUCCESS)
{
fprintf(stderr, "Failure (%%i) in destroying CURAND generator",
fprintf(stderr, "Failure (%i) in destroying CURAND generator.\\n",
(int)err);
}
free(_gen);
free(gen);
}
"""
......@@ -136,10 +143,6 @@ class CURAND_Base(GpuOp):
code = """
//////// <code generated by CURAND_Base>
#if PY_MAJOR_VERSION >= 3
#include "numpy/npy_3kcompat.h"
#endif
int odims[%(ndim)s];
int n_elements = 1;
int must_alloc_sample = ((NULL == %(o_sample)s)
......
......@@ -2479,7 +2479,7 @@ class Shape(Op):
#TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables.
# Do not continue this madness.
return super(Shape_i, self).c_code(node, name, (x,), (out,), sub)
return super(Shape, self).c_code(node, nodename, (x,), (out,), sub)
def c_code_cache_version(self):
return (1,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论