void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) {
void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int offsetY, int offsetZ, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY_input = (blockIdx.y + dataOffset) * inputStrides.y;
int idY = blockIdx.y + offsetY;
int dataOffsetY_output = (blockIdx.y + dataOffset) * outputStrides.y;
int idZ = blockIdx.z + offsetZ;
int dataOffsetY_input = idY * inputStrides.y + idZ * inputStrides.z;
int dataOffsetY_output = idY * outputStrides.y + idZ * outputStrides.z;
int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input;
int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input;
int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output;
int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output;
...
@@ -127,39 +138,42 @@ class GpuCumsum(CumsumOp, GpuOp):
...
@@ -127,39 +138,42 @@ class GpuCumsum(CumsumOp, GpuOp):
}
}
__global__
__global__
void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int numElements, dim3 dataStrides, int dataOffset) {
void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int nbElementsPerCumsum, dim3 dataStrides, int offsetY, int offsetZ) {
int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
// Check if current has data to process.
// Check if current has data to process.
if (globalThreadID >= ceil(numElements/2.0)) {
if (globalThreadID >= ceil(nbElementsPerCumsum/2.0)) {