Commit 2ad36935, authored by Marc-Alexandre Cote

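Force cast of dim3's attributes to int: dim3 members are unsigned, so each index is now computed into an int variable before it is used to subscript the input/output arrays.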

Parent: 37b40abd
...@@ -78,7 +78,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -78,7 +78,7 @@ class GpuCumsum(CumsumOp, GpuOp):
compute_map, no_recycling) compute_map, no_recycling)
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
...@@ -107,23 +107,28 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -107,23 +107,28 @@ class GpuCumsum(CumsumOp, GpuOp):
__device__ __device__
void k_fetchData_%(nodename)s(float* partialCumSum, float* input, int globalThreadID, dim3 dataStrides, int dataOffset) { void k_fetchData_%(nodename)s(float* partialCumSum, float* input, int globalThreadID, dim3 dataStrides, int dataOffset) {
// blockIdx.y represents the # of the current independent cumsum // blockIdx.y represents the # of the current independent cumsum
partialCumSum[threadIdx.x*2] = input[(globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y]; int idx_even = (globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y;
partialCumSum[threadIdx.x*2 + 1] = input[(globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y]; int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y;
partialCumSum[threadIdx.x*2] = input[idx_even];
partialCumSum[threadIdx.x*2 + 1] = input[idx_odd];
} }
__device__ __device__
void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int dataOffset) { void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int dataOffset) {
__syncthreads(); __syncthreads();
// blockIdx.y represents the # of the current independent cumsum // blockIdx.y represents the # of the current independent cumsum
output[(globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y] = partialCumSum[threadIdx.x*2]; int idx_even = (globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y;
output[(globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y] = partialCumSum[threadIdx.x*2 + 1]; int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y;
output[idx_even] = partialCumSum[threadIdx.x*2];
output[idx_odd] = partialCumSum[threadIdx.x*2 + 1];
} }
__global__ __global__
void k_cumadd_%(nodename)s(float* input, float* output, dim3 dataStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) { void k_cumadd_%(nodename)s(float* input, float* output, dim3 dataStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY = (blockIdx.y + dataOffset) * dataStrides.y; int dataOffsetY = (blockIdx.y + dataOffset) * dataStrides.y;
output[lastElementIdx*dataStrides.x + dataOffsetY] = input[lastElementIdx*dataStrides.x + dataOffsetY] int idx_last = lastElementIdx*dataStrides.x + dataOffsetY;
+ output[beforeLastElementIdx*dataStrides.x + dataOffsetY]; int idx_beforelast = beforeLastElementIdx*dataStrides.x + dataOffsetY;
output[idx_last] = input[idx_last] + output[idx_beforelast];
} }
__global__ __global__
...@@ -137,9 +142,11 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -137,9 +142,11 @@ class GpuCumsum(CumsumOp, GpuOp):
const float currentBlockSum = blockSum[blockIdx.x*gridDim.y + blockIdx.y + dataOffset]; const float currentBlockSum = blockSum[blockIdx.x*gridDim.y + blockIdx.y + dataOffset];
int dataOffsetY = (blockIdx.y + dataOffset) * dataStrides.y; int dataOffsetY = (blockIdx.y + dataOffset) * (int)dataStrides.y;
output[(globalThreadID*2 ) * dataStrides.x + dataOffsetY] += currentBlockSum; int idx_even = (globalThreadID*2 ) * dataStrides.x + dataOffsetY;
output[(globalThreadID*2 + 1) * dataStrides.x + dataOffsetY] += currentBlockSum; int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + dataOffsetY;
output[idx_even] += currentBlockSum;
output[idx_odd] += currentBlockSum;
} }
__global__ __global__
......
Markdown 格式
0%
You added 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment