提交 4e9eafa6 authored 作者: Gabe Schwartz's avatar Gabe Schwartz

Fixed CURAND wrapper to work with capsules.

The capsule API destructor gets passed a pointer to the capsule, not the underlying C object as in the old PyCObject API.
上级 051a6b95
...@@ -107,16 +107,23 @@ class CURAND_Base(GpuOp): ...@@ -107,16 +107,23 @@ class CURAND_Base(GpuOp):
def c_support_code(self): def c_support_code(self):
return """ return """
#if PY_MAJOR_VERSION >= 3
void free_generator(PyObject *_gen)
{
curandGenerator_t * gen = (curandGenerator_t*)NpyCapsule_AsVoidPtr(_gen);
#else
void free_generator(void *_gen) void free_generator(void *_gen)
{ {
curandGenerator_t * gen = (curandGenerator_t*)_gen; curandGenerator_t * gen = (curandGenerator_t*)_gen;
#endif
curandStatus_t err = curandDestroyGenerator(*gen); curandStatus_t err = curandDestroyGenerator(*gen);
if (err != CURAND_STATUS_SUCCESS) if (err != CURAND_STATUS_SUCCESS)
{ {
fprintf(stderr, "Failure (%%i) in destroying CURAND generator", fprintf(stderr, "Failure (%i) in destroying CURAND generator.\\n",
(int)err); (int)err);
} }
free(_gen); free(gen);
} }
""" """
...@@ -136,10 +143,6 @@ class CURAND_Base(GpuOp): ...@@ -136,10 +143,6 @@ class CURAND_Base(GpuOp):
code = """ code = """
//////// <code generated by CURAND_Base> //////// <code generated by CURAND_Base>
#if PY_MAJOR_VERSION >= 3
#include "numpy/npy_3kcompat.h"
#endif
int odims[%(ndim)s]; int odims[%(ndim)s];
int n_elements = 1; int n_elements = 1;
int must_alloc_sample = ((NULL == %(o_sample)s) int must_alloc_sample = ((NULL == %(o_sample)s)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论