提交 057f6a5f authored 作者: Gabe Schwartz's avatar Gabe Schwartz

PyCObject instead of NpyCapsule for py3k compat.

上级 85f83402
...@@ -9,6 +9,7 @@ __contact__ = "theano-dev@googlegroups.com" ...@@ -9,6 +9,7 @@ __contact__ = "theano-dev@googlegroups.com"
import numpy import numpy
import theano.gof import theano.gof
from theano.compat import PY3
from theano.gof.python25 import all from theano.gof.python25 import all
from theano.sandbox.cuda import CudaNdarrayType, GpuOp from theano.sandbox.cuda import CudaNdarrayType, GpuOp
from theano.tensor import (get_vector_length, cast, opt) from theano.tensor import (get_vector_length, cast, opt)
...@@ -133,8 +134,11 @@ class CURAND_Base(GpuOp): ...@@ -133,8 +134,11 @@ class CURAND_Base(GpuOp):
else: else:
otype = 'double' otype = 'double'
return """ code = """
//////// <code generated by CURAND_Base> //////// <code generated by CURAND_Base>
#if PY_MAJOR_VERSION >= 3
#include "numpy/npy_3kcompat.h"
#endif
int odims[%(ndim)s]; int odims[%(ndim)s];
int n_elements = 1; int n_elements = 1;
...@@ -221,8 +225,12 @@ class CURAND_Base(GpuOp): ...@@ -221,8 +225,12 @@ class CURAND_Base(GpuOp):
//////// </ code generated by CURAND_Base> //////// </ code generated by CURAND_Base>
""" % locals() """ % locals()
if PY3:
code = code.replace("PyCObject", "NpyCapsule")
return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
class CURAND_Normal(CURAND_Base): class CURAND_Normal(CURAND_Base):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论