提交 201dbaf8 作者: Pascal Lamblin

Make GpuDownsampleFactorMax support a non-contiguous output

上级 0f63d704
...@@ -801,7 +801,7 @@ class GpuDownsampleFactorMax(GpuOp): ...@@ -801,7 +801,7 @@ class GpuDownsampleFactorMax(GpuOp):
#def perform(self, node, input_storage, output_storage): #def perform(self, node, input_storage, output_storage):
#raise NotImplementedError('only C is implemented') #raise NotImplementedError('only C is implemented')
def c_code_cache_version(self): def c_code_cache_version(self):
return (3) return (4)
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
x, = inp x, = inp
...@@ -867,7 +867,11 @@ class GpuDownsampleFactorMax(GpuOp): ...@@ -867,7 +867,11 @@ class GpuDownsampleFactorMax(GpuOp):
CudaNdarray_HOST_STRIDES(%(x)s)[1], CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_HOST_STRIDES(%(x)s)[2], CudaNdarray_HOST_STRIDES(%(x)s)[2],
CudaNdarray_HOST_STRIDES(%(x)s)[3], CudaNdarray_HOST_STRIDES(%(x)s)[3],
CudaNdarray_DEV_DATA(%(z)s)); CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1],
CudaNdarray_HOST_STRIDES(%(z)s)[2],
CudaNdarray_HOST_STRIDES(%(z)s)[3]);
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if( cudaSuccess != err) if( cudaSuccess != err)
...@@ -894,7 +898,7 @@ class GpuDownsampleFactorMax(GpuOp): ...@@ -894,7 +898,7 @@ class GpuDownsampleFactorMax(GpuOp):
__global__ void kMaxPool_%(nodename)s( __global__ void kMaxPool_%(nodename)s(
int D0, int D1, int D2, int D3, int xD2, int xD3, int D0, int D1, int D2, int D3, int xD2, int xD3,
const float * x, int xS0, int xS1, int xS2, int xS3, const float * x, int xS0, int xS1, int xS2, int xS3,
float *z) float *z, int zS0, int zS1, int zS2, int zS3)
{ {
float cur_max, cur_x; float cur_max, cur_x;
int i0 = blockIdx.x %% D0; int i0 = blockIdx.x %% D0;
...@@ -943,7 +947,7 @@ class GpuDownsampleFactorMax(GpuOp): ...@@ -943,7 +947,7 @@ class GpuDownsampleFactorMax(GpuOp):
} }
//store the result to global memory //store the result to global memory
z[i0 * D1*D2*D3 + i1*D2*D3 + i2*D3 + threadIdx.x] = cur_max; z[i0*zS0 + i1*zS1 + i2*zS2 + threadIdx.x*zS3] = cur_max;
} }
""" % locals() """ % locals()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论