提交 6dbb2457 authored 作者: Marc-Alexandre Cote's avatar Marc-Alexandre Cote

Cumsum 2D in cuda is working when axis=1.

上级 34239976
...@@ -68,13 +68,11 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -68,13 +68,11 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
} }
__global__ __global__
void k_cumadd_%(nodename)s(float* input, float* output, int beforeLastElementIdx, int lastElementIdx) { void k_cumadd_%(nodename)s(float* input, float* output, int beforeLastElementIdx, int lastElementIdx) {
output[lastElementIdx] = input[lastElementIdx] + output[beforeLastElementIdx]; output[lastElementIdx] = input[lastElementIdx] + output[beforeLastElementIdx];
} }
__global__ __global__
void k_blockCumSum_1D_%(nodename)s(float* input, float* output, int numElements, float* blockSum) { void k_blockCumSum_1D_%(nodename)s(float* input, float* output, int numElements, float* blockSum) {
int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x; int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -105,7 +103,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -105,7 +103,7 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
} }
// Wtite the final output to global memory // Write the final output to global memory
__syncthreads(); __syncthreads();
output[globalThreadID*2] = partialCumSum[threadIdx.x*2]; output[globalThreadID*2] = partialCumSum[threadIdx.x*2];
output[globalThreadID*2 + 1] = partialCumSum[threadIdx.x*2 + 1]; output[globalThreadID*2 + 1] = partialCumSum[threadIdx.x*2 + 1];
...@@ -173,6 +171,141 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -173,6 +171,141 @@ class GpuCumsum(CumsumOp, GpuOp):
cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum)); cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum));
cudaThreadSynchronize(); cudaThreadSynchronize();
} }
// Second pass of the blocked cumulative sum along axis 1: adds the running
// total of all preceding blocks of a row to every element of the next block.
//
// Launch layout used by the host wrapper: grid = (rows, blocks per row),
// block = (1, blockSize). Each thread updates two consecutive elements.
//   output      - rows already holding per-block cumulative sums; updated in place.
//   blockSum    - cumulative block totals, read as [row * gridDim.y + block].
//   numElements - number of elements per row covered by the blocked scan.
//   dataStrides - only .x is read, as the element offset between two rows.
__global__
void k_finalCumSum_2D_axis1_%(nodename)s(float* output, float* blockSum, int numElements, dim3 dataStrides) {
    // Block b patches the elements owned by block b+1 (hence the "+ 1"):
    // the first block of a row has no predecessor and needs no correction.
    int globalThreadID = (blockIdx.y + 1) * blockDim.y + threadIdx.y;

    // Check if current thread has data to process. A thread owns elements
    // 2*id and 2*id+1; this integer test is equivalent to
    // id >= ceil(numElements/2.0) without double-precision math on the device.
    if (globalThreadID*2 >= numElements) {
        return;
    }

    const float currentBlockSum = blockSum[blockIdx.x*gridDim.y + blockIdx.y];
    const unsigned int rowOffset = blockIdx.x*dataStrides.x;

    output[globalThreadID*2     + rowOffset] += currentBlockSum;
    output[globalThreadID*2 + 1 + rowOffset] += currentBlockSum;
}
// Finishes the cumulative sum of rows with an odd number of elements:
//   output[row][last] = input[row][last] + output[row][beforeLast]
// One block per row (row = blockIdx.x). Rows are addressed as consecutive
// runs of (lastElementIdx + 1) floats, i.e. no stride information is used.
__global__
void k_cumadd_2D_axis1_%(nodename)s(float* input, float* output, int beforeLastElementIdx, int lastElementIdx) {
    const unsigned int rowStart = blockIdx.x * (lastElementIdx + 1);
    const float previousTotal = output[rowStart + beforeLastElementIdx];
    output[rowStart + lastElementIdx] = input[rowStart + lastElementIdx] + previousTotal;
}
// First pass of the blocked cumulative sum along axis 1: every thread block
// computes the inclusive cumulative sum of its own slice of one row in shared
// memory, and optionally records the slice total in blockSum.
//
// Launch layout used by the host wrapper: grid = (rows, blocks per row),
// block = (1, blockSize), dynamic shared memory = 2 * blockSize floats.
// Each thread handles two consecutive elements.
//   input, output - source and destination rows (the wrapper also calls this in place).
//   numElements   - elements per row to scan; the host wrapper always passes an even count.
//   dataStrides   - only .x is read, as the element offset between two rows.
//   blockSum      - if not NULL, receives each block total at [row * gridDim.y + block].
__global__
void k_blockCumSum_2D_axis1_%(nodename)s(float* input, float* output, int numElements, dim3 dataStrides, float* blockSum) {
    int globalThreadID = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int rowOffset = blockIdx.x*dataStrides.x;

    // Threads past the end of the row (possible in the last block of a row)
    // must NOT return early: every thread of the block has to reach each
    // __syncthreads() below. They take part in the scan with zeros instead
    // and simply skip the global memory accesses.
    const bool hasData = (globalThreadID*2 < numElements);

    extern __shared__ float partialCumSum[];

    // Load data in shared memory (zero is the neutral element of the sum)
    if (hasData) {
        partialCumSum[threadIdx.y*2]     = input[globalThreadID*2     + rowOffset];
        partialCumSum[threadIdx.y*2 + 1] = input[globalThreadID*2 + 1 + rowOffset];
    } else {
        partialCumSum[threadIdx.y*2]     = 0.0f;
        partialCumSum[threadIdx.y*2 + 1] = 0.0f;
    }

    // Reduction Phase
    int stride;
    for (stride = 1; stride <= blockDim.y; stride *= 2) {
        __syncthreads();
        int index = (threadIdx.y + 1) * (stride * 2) - 1;
        if(index < blockDim.y*2) {
            partialCumSum[index] += partialCumSum[index - stride];
        }
    }

    // Reverse Phase (starts from the stride the reduction loop ended with)
    for (; stride > 0; stride /= 2) {
        __syncthreads();
        int index = (threadIdx.y + 1) * (stride * 2) - 1;
        if(index + stride < blockDim.y*2) {
            partialCumSum[index + stride] += partialCumSum[index];
        }
    }

    // Write the final output to global memory
    __syncthreads();
    if (hasData) {
        output[globalThreadID*2     + rowOffset] = partialCumSum[threadIdx.y*2];
        output[globalThreadID*2 + 1 + rowOffset] = partialCumSum[threadIdx.y*2 + 1];
    }

    // The last slot of the shared buffer holds the total of this block;
    // thanks to the zero padding this is also true for a partially filled block.
    if (blockSum != NULL){
        if (threadIdx.y == blockDim.y - 1) {
            blockSum[blockIdx.x*gridDim.y + blockIdx.y] = partialCumSum[threadIdx.y*2 + 1];
        }
    }
}
// Host driver for the cumulative sum of a 2D array along axis 1.
//
// Steps:
//   1. k_blockCumSum_2D_axis1 scans each block-sized slice of every row and
//      stores the slice totals in a temporary (rows x blocks) array.
//   2. If a row spans several blocks, the totals are themselves scanned by
//      calling this function recursively (in place), then added to the
//      following blocks by k_finalCumSum_2D_axis1.
//   3. The blocked scan only covers an even number of elements per row; for
//      an odd row length the last element is added by k_cumadd_2D_axis1.
//
//   input, output - source and destination arrays (may be the same object).
//   shape         - the two dimensions of input.
//   maxThreads    - upper bound for the number of threads per block.
//
// Errors: CUDA launch errors are left for the caller to query. If the
// temporary array cannot be allocated, the function returns early with the
// Python exception raised by the allocator still set.
//
// NOTE(review): the kernels receive the strides of input only and apply
// them to output as well, and k_cumadd_2D_axis1 takes no strides at all, so
// both arrays are presumably expected to be C-contiguous - confirm in the caller.
void cumSum_2D_axis1_%(nodename)s(CudaNdarray* input, CudaNdarray* output, const int* shape, int maxThreads) {
    int axis = 1; // Convert into a parameter

    if (shape[axis] <= 1) {
        CudaNdarray_CopyFromCudaNdarray(output, input);
        return;
    }

    // No rows: nothing to compute, and a grid with zero blocks is not a valid launch.
    if (shape[0] == 0) {
        return;
    }

    // The blocked scan processes pairs of elements, so round down to an even count.
    int numElements = shape[axis] - (shape[axis] %% 2);
    int blockSize = ceil( min(numElements, 2*maxThreads) / 2.0);
    int dimGridX = shape[0];
    int dimGridY = ceil(numElements / (2.0*blockSize));

    const int shapeBlockSum[2] = { dimGridX, dimGridY };
    CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_ZEROS(2, (int*)shapeBlockSum);
    if (deviceBlockSum == NULL) {
        // Allocation failed; do not dereference a NULL array below.
        return;
    }

    dim3 dimBlock(1, blockSize, 1);
    dim3 dimGrid(dimGridX, dimGridY, 1);
    int sharedBytes = (2*blockSize) * sizeof(float);
    dim3 dataStrides(CudaNdarray_HOST_STRIDES(input)[0], CudaNdarray_HOST_STRIDES(input)[1], 0);

    cudaThreadSynchronize();
    k_blockCumSum_2D_axis1_%(nodename)s<<<dimGrid, dimBlock, sharedBytes>>>
    (
        CudaNdarray_DEV_DATA(input),
        CudaNdarray_DEV_DATA(output),
        numElements,
        dataStrides,
        CudaNdarray_DEV_DATA(deviceBlockSum)
    );

    if (dimGridY > 1) {
        // Do a cumsum over the blockSum (recursive).
        cumSum_2D_axis1_%(nodename)s(deviceBlockSum, deviceBlockSum, shapeBlockSum, maxThreads);

        // Same launch geometry as the first pass; no shared memory needed.
        k_finalCumSum_2D_axis1_%(nodename)s<<<dimGrid, dimBlock>>>
        (
            CudaNdarray_DEV_DATA(output),
            CudaNdarray_DEV_DATA(deviceBlockSum),
            numElements,
            dataStrides
        );
    }

    // If shape[axis] is odd, the last element is computed manually
    if (shape[axis] != numElements){
        cudaThreadSynchronize();
        dim3 dimGridLast(dimGridX, 1, 1);
        dim3 dimBlockLast(1, 1, 1);
        k_cumadd_2D_axis1_%(nodename)s<<<dimGridLast, dimBlockLast>>>
        (
            CudaNdarray_DEV_DATA(input),
            CudaNdarray_DEV_DATA(output),
            shape[axis]-2,
            shape[axis]-1
        );
    }

    // Wait for the kernels that still read the temporary array, then release
    // the whole array object (not just its device buffer) through its
    // reference count so nothing is leaked.
    cudaThreadSynchronize();
    Py_DECREF(deviceBlockSum);
}
""" % locals() """ % locals()
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
...@@ -190,8 +323,9 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -190,8 +323,9 @@ class GpuCumsum(CumsumOp, GpuOp):
"related to the selected GPU.") "related to the selected GPU.")
sub.update(locals()) sub.update(locals())
#Right now, only the 1D case implementation exists. #Right now, only the 1D case works.
if self.axis is None or (self.axis == 0 and node.inputs[0].ndim == 1):
code = """ code = """
npy_intp shape[1] = { CudaNdarray_SIZE(%(x)s) }; npy_intp shape[1] = { CudaNdarray_SIZE(%(x)s) };
if(! (%(z)s && CudaNdarray_HOST_DIMS(%(z)s)[0] == shape[0]) ) { if(! (%(z)s && CudaNdarray_HOST_DIMS(%(z)s)[0] == shape[0]) ) {
...@@ -217,6 +351,45 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -217,6 +351,45 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
} }
""" % locals() """ % locals()
elif node.inputs[0].ndim == 2 and self.axis == 1:
code = """
const int* shape = CudaNdarray_HOST_DIMS(%(x)s);
bool needAllocation = !%(z)s || CudaNdarray_NDIM(%(x)s) != CudaNdarray_NDIM(%(z)s);
// If output is already allocated, check if its shape matches the input's one.
if (!needAllocation) {
for (int i= 0; i < CudaNdarray_NDIM(%(x)s); ++i) {
if (CudaNdarray_HOST_DIMS(%(x)s)[i] == CudaNdarray_HOST_DIMS(%(z)s)[i]) {
needAllocation = true;
}
}
}
if (needAllocation){
Py_XDECREF(%(z)s);
%(z)s = (CudaNdarray*) CudaNdarray_NewDims(CudaNdarray_NDIM(%(x)s), shape);
}
if (!%(z)s) {
%(fail)s;
}
{ // Namespace for kernel calls //
cumSum_2D_axis1_%(nodename)s(%(x)s, %(z)s, shape, %(max_threads_dim0)s);
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error: %%s: %%s.\\n",
"cumSum_2D_axis1_%(nodename)s",
cudaGetErrorString(sts));
%(fail)s;
}
}
""" % locals()
else:
raise NotImplementedError('Only 1D case and 2D (axis=1) are supported right now!')
return code return code
......
...@@ -31,41 +31,118 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -31,41 +31,118 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
x = T.vector('x') x = T.vector('x')
f = theano.function([x], cumsum(x)) f = theano.function([x], cumsum(x))
# Even number of elements # # Even number of elements
a = np.random.random((18,)).astype(config.floatX) # a = np.random.random((18,)).astype(config.floatX)
assert np.allclose(np.cumsum(a), f(a)) # assert np.allclose(np.cumsum(a), f(a))
# Odd number of elements # # Odd number of elements
a = np.random.random((7,)).astype(config.floatX) # a = np.random.random((7,)).astype(config.floatX)
assert np.allclose(np.cumsum(a), f(a)) # assert np.allclose(np.cumsum(a), f(a))
# Use multiple GPU threadblocks # # Use multiple GPU threadblocks
a = np.random.random((2048+1,)).astype(config.floatX) # a = np.random.random((2048+2,)).astype(config.floatX)
assert np.allclose(np.cumsum(a), f(a)) # assert np.allclose(np.cumsum(a), f(a))
# Use multiple GPU threadblocks # # Use multiple GPU threadblocks
a = np.random.random((2048*75+1,)).astype(config.floatX) # a = np.random.random((2048*75+2,)).astype(config.floatX)
assert np.allclose(np.cumsum(a), f(a)) # assert np.allclose(np.cumsum(a), f(a))
# Use multiple GPU gridblocks # # Use multiple GPU gridblocks
a = np.ones((2048*2048+1,)).astype(config.floatX) # a = np.ones((2048*2048+2,)).astype(config.floatX)
assert np.allclose(np.cumsum(a), f(a)) # assert np.allclose(np.cumsum(a), f(a))
print "\nBenchmark:"
# Extensive testing
i = 0; import timeit as t
while True: #theano_time = t.timeit("np.ones((100,))", "import numpy as np", number=1000)
a = np.ones((i,), dtype=config.floatX)
fa = f(a) stmt = "f(a)"
npa = np.cumsum(a) setup = """
import numpy as np
if not np.allclose(npa, fa): import theano
print i, np.allclose(npa, fa) # Test axis=None import theano.tensor as T
print fa from theano.tensor.extra_ops import cumsum
print npa from theano import config
assert False x = T.vector('x')
f = theano.function([x], cumsum(x))
if i % 1000 == 0: a = np.ones((100000,), dtype=config.floatX)
print i """.replace(" ", "")
theano_time = t.timeit(stmt, setup, number=1000)
i += 1 print "Theano:\t", theano_time
stmt = "np.cumsum(a)"
setup = """
import numpy as np
from theano import config
a = np.ones((100000,), dtype=config.floatX)
""".replace(" ", "")
numpy_time = t.timeit(stmt, setup, number=1000)
print "Numpy:\t", numpy_time
print "Speedup: {0}x".format(numpy_time/theano_time)
# # Extensive testing
# i = 0;
# while True:
# a = np.ones((i,), dtype=config.floatX)
# fa = f(a)
# npa = np.cumsum(a)
# if not np.allclose(npa, fa):
# print i, np.allclose(npa, fa) # Test axis=None
# print fa
# print npa
# assert False
# if i % 1000 == 0:
# print i
# i += 1
# ### Test 2D case - axis=1 ###
# x = T.matrix('x')
# f = theano.function([x], cumsum(x, axis=1))
# # # Even number of elements
# # print "\n# Even number of elements"
# # a = np.random.random((18,18)).astype(config.floatX)
# # assert np.allclose(np.cumsum(a, axis=1), f(a))
# # # Odd number of elements
# # print "\n# Odd number of elements"
# # assert np.allclose(np.cumsum(a, axis=1), f(a))
# # # Use multiple GPU threadblocks
# # print "\n# Use multiple GPU threadblocks"
# # a = np.random.random((2048+2,2048+2)).astype(config.floatX)
# # assert np.allclose(np.cumsum(a, axis=1), f(a))
# # # Use multiple GPU threadblocks
# # print "\n# Use multiple GPU threadblocks"
# # a = np.ones((10,2048*75+3)).astype(config.floatX)
# # assert np.allclose(np.cumsum(a, axis=1), f(a))
# # # Use multiple GPU gridblocks
# # print "\n# Use multiple GPU gridblocks"
# # a = np.ones((11,2048*2048+3)).astype(config.floatX)
# # assert np.allclose(np.cumsum(a, axis=1), f(a))
# # Extensive testing
# i = 19000;
# while True:
# a = np.ones((11,i), dtype=config.floatX)
# fa = f(a)
# npa = np.cumsum(a, axis=1)
# if not np.allclose(npa, fa):
# print i, np.allclose(npa, fa) # Test axis=None
# print fa
# print npa
# assert False
# if i % 1000 == 0:
# print i
# i += 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论