提交 1bafd598 作者: Pierre Luc Carrier

Replaced calls to PyGpuArray_DATA with calls to cuda_get_ptr()

上级 a3810051
...@@ -43,6 +43,9 @@ class GpuImages2Neibs(Images2Neibs, Op): ...@@ -43,6 +43,9 @@ class GpuImages2Neibs(Images2Neibs, Op):
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
def c_init_code(self):
return ['setup_ext_cuda();']
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
mode = self.mode mode = self.mode
...@@ -197,12 +200,13 @@ class GpuImages2Neibs(Images2Neibs, Op): ...@@ -197,12 +200,13 @@ class GpuImages2Neibs(Images2Neibs, Op):
""" % locals() """ % locals()
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
dtype_ten4 = node.inputs[0].dtype
dtype_z = node.outputs[0].dtype
typecode_z = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
ten4, neib_shape, neib_step = inp ten4, neib_shape, neib_step = inp
z, = out z, = out
fail = sub['fail'] fail = sub['fail']
mode = self.mode mode = self.mode
typecode_z = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
return """ return """
#ifndef CEIL_INTDIV #ifndef CEIL_INTDIV
#define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0)) #define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0))
...@@ -386,10 +390,13 @@ class GpuImages2Neibs(Images2Neibs, Op): ...@@ -386,10 +390,13 @@ class GpuImages2Neibs(Images2Neibs, Op):
PyGpuArray_STRIDES(%(ten4)s)[1], PyGpuArray_STRIDES(%(ten4)s)[1],
PyGpuArray_STRIDES(%(ten4)s)[2], PyGpuArray_STRIDES(%(ten4)s)[2],
PyGpuArray_STRIDES(%(ten4)s)[3], PyGpuArray_STRIDES(%(ten4)s)[3],
PyGpuArray_DATA(%(ten4)s), (npy_%(dtype_ten4)s*)(
((char *)cuda_get_ptr(%(ten4)s->ga.data)) +
%(ten4)s->ga.offset),
PyGpuArray_STRIDES(%(z)s)[0], PyGpuArray_STRIDES(%(z)s)[0],
PyGpuArray_STRIDES(%(z)s)[1], PyGpuArray_STRIDES(%(z)s)[1],
PyGpuArray_DATA(%(z)s) (npy_%(dtype_z)s*)(((char *)cuda_get_ptr(%(z)s->ga.data)) +
%(z)s->ga.offset),
); );
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError(); cudaError_t sts = cudaGetLastError();
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论