void k_fetchData_%(nodename)s(float* partialCumSum, float* input, int globalThreadID, dim3 dataStrides, int offsetY, int offsetZ, int nbElementsPerCumsum) {
// blockIdx.y and blockIdx.z represents the current independent cumsum
int idY = blockIdx.y + offsetY;
int idZ = blockIdx.z + offsetZ;
int offset = idY * dataStrides.y + idZ * dataStrides.z;
int idx_even = (globalThreadID*2 ) * dataStrides.x + offset;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + offset;
void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int dataOffset) {
void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int offsetY, int offsetZ, int nbElementsPerCumsum) {
__syncthreads();
// blockIdx.y represents the # of the current independent cumsum
void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY_input = (blockIdx.y + dataOffset) * inputStrides.y;
int dataOffsetY_output = (blockIdx.y + dataOffset) * outputStrides.y;
void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int offsetY, int offsetZ, int beforeLastElementIdx, int lastElementIdx) {
int idY = blockIdx.y + offsetY;
int idZ = blockIdx.z + offsetZ;
int dataOffsetY_input = idY * inputStrides.y + idZ * inputStrides.z;
int dataOffsetY_output = idY * outputStrides.y + idZ * outputStrides.z;
int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input;
int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output;
...
...
@@ -127,39 +138,42 @@ class GpuCumsum(CumsumOp, GpuOp):
}
__global__
void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int numElements, dim3 dataStrides, int dataOffset) {
void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int nbElementsPerCumsum, dim3 dataStrides, int offsetY, int offsetZ) {
int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
// Check if current has data to process.
if (globalThreadID >= ceil(numElements/2.0)) {
if (globalThreadID >= ceil(nbElementsPerCumsum/2.0)) {