提交 121d0c9d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1414 from HapeMask/py3k-fixes

Use NpyCapsule instead of PyCObject for py3k compat.
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define _CUDA_NDARRAY_H #define _CUDA_NDARRAY_H
// Defines for Python 2/3 compatibility. // Defines for Python 2/3 compatibility.
#if PY_MAJOR_VERSION == 3 #if PY_MAJOR_VERSION >= 3
// Py3k treats all ints as longs. // Py3k treats all ints as longs.
#define PyInt_Check PyLong_Check #define PyInt_Check PyLong_Check
#define PyInt_CheckExact PyLong_CheckExact #define PyInt_CheckExact PyLong_CheckExact
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_Size PyUnicode_GET_SIZE #define PyString_Size PyUnicode_GET_SIZE
#include "numpy/npy_3kcompat.h"
#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr #define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr
#define PyCObject_GetDesc NpyCapsule_GetDesc #define PyCObject_GetDesc NpyCapsule_GetDesc
#define PyCObject_Check NpyCapsule_Check #define PyCObject_Check NpyCapsule_Check
......
...@@ -9,6 +9,7 @@ __contact__ = "theano-dev@googlegroups.com" ...@@ -9,6 +9,7 @@ __contact__ = "theano-dev@googlegroups.com"
import numpy import numpy
import theano.gof import theano.gof
from theano.compat import PY3
from theano.gof.python25 import all from theano.gof.python25 import all
from theano.sandbox.cuda import CudaNdarrayType, GpuOp from theano.sandbox.cuda import CudaNdarrayType, GpuOp
from theano.tensor import (get_vector_length, cast, opt) from theano.tensor import (get_vector_length, cast, opt)
...@@ -133,8 +134,11 @@ class CURAND_Base(GpuOp): ...@@ -133,8 +134,11 @@ class CURAND_Base(GpuOp):
else: else:
otype = 'double' otype = 'double'
return """ code = """
//////// <code generated by CURAND_Base> //////// <code generated by CURAND_Base>
#if PY_MAJOR_VERSION >= 3
#include "numpy/npy_3kcompat.h"
#endif
int odims[%(ndim)s]; int odims[%(ndim)s];
int n_elements = 1; int n_elements = 1;
...@@ -221,8 +225,12 @@ class CURAND_Base(GpuOp): ...@@ -221,8 +225,12 @@ class CURAND_Base(GpuOp):
//////// </ code generated by CURAND_Base> //////// </ code generated by CURAND_Base>
""" % locals() """ % locals()
if PY3:
code = code.replace("PyCObject", "NpyCapsule")
return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
class CURAND_Normal(CURAND_Base): class CURAND_Normal(CURAND_Base):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论