提交 35e15192 作者: abergeron

Merge pull request #2171 from MarcCote/cumsum_3D

Support GpuCumsum on 3D arrays.
import theano import theano
import copy import copy
from theano import Op, Apply from theano import Op
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available, GpuOp from theano.sandbox.cuda import cuda_available, GpuOp
...@@ -14,8 +14,8 @@ if cuda_available: ...@@ -14,8 +14,8 @@ if cuda_available:
class GpuCumsum(CumsumOp, GpuOp): class GpuCumsum(CumsumOp, GpuOp):
SUPPORTED_NDIMS = 2 SUPPORTED_NDIMS = 3
__props__ = ('axis', 'max_threads_dim0', 'max_grid_size1') __props__ = ('axis', 'max_threads_dim0', 'max_grid_size1', 'max_grid_size2')
def __init__(self, axis): def __init__(self, axis):
""" """
...@@ -24,6 +24,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -24,6 +24,7 @@ class GpuCumsum(CumsumOp, GpuOp):
self.axis = axis self.axis = axis
self.max_threads_dim0 = None self.max_threads_dim0 = None
self.max_grid_size1 = None self.max_grid_size1 = None
self.max_grid_size2 = None
def perform(self, node, inp, out): def perform(self, node, inp, out):
return Op.perform(self, node, inp, out) return Op.perform(self, node, inp, out)
...@@ -34,7 +35,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -34,7 +35,7 @@ class GpuCumsum(CumsumOp, GpuOp):
raise TypeError('x must be a CudaNdarrayType', x) raise TypeError('x must be a CudaNdarrayType', x)
if x.ndim > GpuCumsum.SUPPORTED_NDIMS: if x.ndim > GpuCumsum.SUPPORTED_NDIMS:
raise NotImplementedError('Only cumsum on 1D and 2D array are supported right now!') raise NotImplementedError('Only cumsum on 1D, 2D and 3D array are supported right now!')
if self.axis >= x.ndim: if self.axis >= x.ndim:
raise ValueError('axis(={1}) out of bounds'.format(self.axis)) raise ValueError('axis(={1}) out of bounds'.format(self.axis))
...@@ -44,7 +45,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -44,7 +45,7 @@ class GpuCumsum(CumsumOp, GpuOp):
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
node_ = copy.copy(node) node_ = copy.copy(node)
assert node.op is node_.op assert node.op is node_.op
if node_.op.max_threads_dim0 is None or node_.op.max_grid_size1 is None: if node_.op.max_threads_dim0 is None or node_.op.max_grid_size1 is None or node_.op.max_grid_size2 is None:
cuda = theano.sandbox.cuda cuda = theano.sandbox.cuda
device_id = cuda.use.device_number device_id = cuda.use.device_number
if device_id is None: if device_id is None:
...@@ -59,6 +60,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -59,6 +60,7 @@ class GpuCumsum(CumsumOp, GpuOp):
prop = cuda_ndarray.device_properties(device_id) prop = cuda_ndarray.device_properties(device_id)
node_.op.max_threads_dim0 = prop['maxThreadsDim0'] node_.op.max_threads_dim0 = prop['maxThreadsDim0']
node_.op.max_grid_size1 = prop['maxGridSize1'] node_.op.max_grid_size1 = prop['maxGridSize1']
node_.op.max_grid_size2 = prop['maxGridSize2']
return super(GpuCumsum, node_.op).make_thunk(node_, storage_map, return super(GpuCumsum, node_.op).make_thunk(node_, storage_map,
compute_map, no_recycling) compute_map, no_recycling)
...@@ -67,7 +69,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -67,7 +69,7 @@ class GpuCumsum(CumsumOp, GpuOp):
return "%s{%s}" % (self.__class__.__name__, self.axis) return "%s{%s}" % (self.__class__.__name__, self.axis)
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (7,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
...@@ -96,28 +98,37 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -96,28 +98,37 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
__device__ __device__
void k_fetchData_%(nodename)s(float* partialCumSum, float* input, int globalThreadID, dim3 dataStrides, int dataOffset) { void k_fetchData_%(nodename)s(float* partialCumSum, float* input, int globalThreadID, dim3 dataStrides, int offsetY, int offsetZ) {
// blockIdx.y represents the # of the current independent cumsum // blockIdx.y and blockIdx.z represents the current independent cumsum
int idx_even = (globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y; int idY = blockIdx.y + offsetY;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y; int idZ = blockIdx.z + offsetZ;
int offset = idY * dataStrides.y + idZ * dataStrides.z;
int idx_even = (globalThreadID*2 ) * dataStrides.x + offset;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + offset;
partialCumSum[threadIdx.x*2] = input[idx_even]; partialCumSum[threadIdx.x*2] = input[idx_even];
partialCumSum[threadIdx.x*2 + 1] = input[idx_odd]; partialCumSum[threadIdx.x*2 + 1] = input[idx_odd];
} }
__device__ __device__
void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int dataOffset) { void k_pushData_%(nodename)s(float* partialCumSum, float* output, int globalThreadID, dim3 dataStrides, int offsetY, int offsetZ) {
__syncthreads(); __syncthreads();
// blockIdx.y represents the # of the current independent cumsum // blockIdx.y and blockIdx.z represents the current independent cumsum
int idx_even = (globalThreadID*2 ) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y; int idY = blockIdx.y + offsetY;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + (blockIdx.y + dataOffset) * dataStrides.y; int idZ = blockIdx.z + offsetZ;
int offset = idY * dataStrides.y + idZ * dataStrides.z;
int idx_even = (globalThreadID*2 ) * dataStrides.x + offset;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + offset;
output[idx_even] = partialCumSum[threadIdx.x*2]; output[idx_even] = partialCumSum[threadIdx.x*2];
output[idx_odd] = partialCumSum[threadIdx.x*2 + 1]; output[idx_odd] = partialCumSum[threadIdx.x*2 + 1];
} }
__global__ __global__
void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) { void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int offsetY, int offsetZ, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY_input = (blockIdx.y + dataOffset) * inputStrides.y; int idY = blockIdx.y + offsetY;
int dataOffsetY_output = (blockIdx.y + dataOffset) * outputStrides.y; int idZ = blockIdx.z + offsetZ;
int dataOffsetY_input = idY * inputStrides.y + idZ * inputStrides.z;
int dataOffsetY_output = idY * outputStrides.y + idZ * outputStrides.z;
int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input; int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input;
int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output; int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output;
...@@ -127,39 +138,42 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -127,39 +138,42 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
__global__ __global__
void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int numElements, dim3 dataStrides, int dataOffset) { void k_finalCumSum_%(nodename)s(float* output, float* blockSum, int nbElementsPerCumsum, dim3 dataStrides, int offsetY, int offsetZ) {
int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x; int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
// Check if current has data to process. // Check if current has data to process.
if (globalThreadID >= ceil(numElements/2.0)) { if (globalThreadID >= ceil(nbElementsPerCumsum/2.0)) {
return; return;
} }
const float currentBlockSum = blockSum[blockIdx.x*gridDim.y + blockIdx.y + dataOffset]; int idY = blockIdx.y + offsetY;
int idZ = blockIdx.z + offsetZ;
const float currentBlockSum = blockSum[blockIdx.x*(gridDim.y*gridDim.z) + idY*gridDim.z + idZ];
int dataOffsetY = (blockIdx.y + dataOffset) * (int)dataStrides.y; int offset = idY * dataStrides.y + idZ * dataStrides.z;
int idx_even = (globalThreadID*2 ) * dataStrides.x + dataOffsetY; int idx_even = (globalThreadID*2 ) * dataStrides.x + offset;
int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + dataOffsetY; int idx_odd = (globalThreadID*2 + 1) * dataStrides.x + offset;
output[idx_even] += currentBlockSum; output[idx_even] += currentBlockSum;
output[idx_odd] += currentBlockSum; output[idx_odd] += currentBlockSum;
} }
__global__ __global__
void k_blockCumSum_%(nodename)s(float* input, float* output, int numElements, dim3 inputStrides, dim3 outputStrides, int dataOffset, float* blockSum) { void k_blockCumSum_%(nodename)s(float* input, float* output, int nbElementsPerCumsum, dim3 inputStrides, dim3 outputStrides, int offsetY, int offsetZ, float* blockSum) {
// Regarding blockIdx and threadIdx, 'Cumsum' is always performed along the X axis. // Regarding blockIdx and threadIdx, 'Cumsum' is always performed along the X axis.
// The Y axis will contain all the independent cumsums of the 2D case. // The Y and Z axis of the grid will contain all independent cumsums of the 2D/3D case.
int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x; int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x;
// Check if current thread has data to process. // Check if current thread has data to process.
if (globalThreadID >= ceil(numElements/2.0)) { if (globalThreadID >= ceil(nbElementsPerCumsum/2.0)) {
return; return;
} }
extern __shared__ float partialCumSum[]; extern __shared__ float partialCumSum[];
// Load data in shared memory // Load data in shared memory
k_fetchData_%(nodename)s(partialCumSum, input, globalThreadID, inputStrides, dataOffset); k_fetchData_%(nodename)s(partialCumSum, input, globalThreadID, inputStrides, offsetY, offsetZ);
// Use a dichotomy approach to compute the cumsum (i.e. balanced binary tree). // Use a dichotomy approach to compute the cumsum (i.e. balanced binary tree).
// The tree is sweeped from the leaves to the root and from the root to the leaves. // The tree is sweeped from the leaves to the root and from the root to the leaves.
...@@ -168,19 +182,19 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -168,19 +182,19 @@ class GpuCumsum(CumsumOp, GpuOp):
k_reversePhase_%(nodename)s(partialCumSum); k_reversePhase_%(nodename)s(partialCumSum);
// Write the final output to global memory // Write the final output to global memory
k_pushData_%(nodename)s(partialCumSum, output, globalThreadID, outputStrides, dataOffset); k_pushData_%(nodename)s(partialCumSum, output, globalThreadID, outputStrides, offsetY, offsetZ);
if (blockSum != NULL){ if (blockSum != NULL){
if (threadIdx.x == blockDim.x - 1) { if (threadIdx.x == blockDim.x - 1) {
blockSum[blockIdx.x*gridDim.y + blockIdx.y + dataOffset] = partialCumSum[threadIdx.x*2 + 1]; blockSum[blockIdx.x*(gridDim.y*gridDim.z) + (blockIdx.y + offsetY)*gridDim.z + blockIdx.z + offsetZ] = partialCumSum[threadIdx.x*2 + 1];
} }
} }
} }
int cumSum_%(nodename)s(CudaNdarray* input, CudaNdarray* output, int maxThreads, int axis, int maxGridY) { int cumSum_%(nodename)s(CudaNdarray* input, CudaNdarray* output, int axis, int maxThreads, int maxGridY, int maxGridZ) {
int shape[2] = { 1, 1 }; int shape[3] = { 1, 1, 1 };
dim3 inputStrides(0,0,0); dim3 inputStrides(0, 0, 0);
dim3 outputStrides(0,0,0); dim3 outputStrides(0, 0, 0);
switch (CudaNdarray_NDIM(input)) switch (CudaNdarray_NDIM(input))
{ {
...@@ -197,8 +211,18 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -197,8 +211,18 @@ class GpuCumsum(CumsumOp, GpuOp):
outputStrides.x = CudaNdarray_HOST_STRIDES(output)[0]; outputStrides.x = CudaNdarray_HOST_STRIDES(output)[0];
outputStrides.y = CudaNdarray_HOST_STRIDES(output)[1]; outputStrides.y = CudaNdarray_HOST_STRIDES(output)[1];
break; break;
case 3:
shape[0] = CudaNdarray_HOST_DIMS(input)[0];
shape[1] = CudaNdarray_HOST_DIMS(input)[1];
shape[2] = CudaNdarray_HOST_DIMS(input)[2];
inputStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
inputStrides.y = CudaNdarray_HOST_STRIDES(input)[1];
inputStrides.z = CudaNdarray_HOST_STRIDES(input)[2];
outputStrides.x = CudaNdarray_HOST_STRIDES(output)[0];
outputStrides.y = CudaNdarray_HOST_STRIDES(output)[1];
outputStrides.z = CudaNdarray_HOST_STRIDES(output)[2];
break;
default: default:
printf("Only 1D and 2D cumsum is implemented yet.\\n");
return -1; return -1;
} }
...@@ -207,64 +231,102 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -207,64 +231,102 @@ class GpuCumsum(CumsumOp, GpuOp):
return 0; return 0;
} }
if (axis == 1) { // Perform cumsum on array of even size.
int tmp = inputStrides.x; int nbElementsPerCumsum = shape[axis] - (shape[axis] %% 2);
// Determine how many elements can be processed in one block.
int dimBlockX = ceil( min(nbElementsPerCumsum, 2*maxThreads) / 2.0);
// Determine how many blocks are needed in total.
int dimGridX = ceil(nbElementsPerCumsum / (2.0*dimBlockX)); // Nb. of blocks needed per cumsum.
int dimGridY; // Nb. of independent cumsums (width).
int dimGridZ; // Nb. of independent cumsums (height).
int tmp;
switch (axis)
{
case 0:
dimGridY = shape[1];
dimGridZ = shape[2];
break;
case 1:
dimGridY = shape[0];
dimGridZ = shape[2];
tmp = inputStrides.x;
inputStrides.x = inputStrides.y; inputStrides.x = inputStrides.y;
inputStrides.y = tmp; inputStrides.y = tmp;
tmp = outputStrides.x; tmp = outputStrides.x;
outputStrides.x = outputStrides.y; outputStrides.x = outputStrides.y;
outputStrides.y = tmp; outputStrides.y = tmp;
} break;
case 2:
dimGridY = shape[1];
dimGridZ = shape[0];
tmp = inputStrides.x;
inputStrides.x = inputStrides.z;
inputStrides.z = tmp;
int numElements = shape[axis] - (shape[axis] %% 2); tmp = outputStrides.x;
int blockSize = ceil( min(numElements, 2*maxThreads) / 2.0); outputStrides.x = outputStrides.z;
int dimGridX = ceil(numElements / (2.0*blockSize)); // Nb. of elements to perform cumsum on. outputStrides.z = tmp;
int dimGridY = shape[1-axis]; // Nb. of independent cumsums. break;
const int shapeBlockSum[2] = { dimGridX, dimGridY }; default:
return -1;
}
const int shapeBlockSum[2] = { dimGridX, dimGridY*dimGridZ };
CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_NewDims(2, shapeBlockSum); CudaNdarray* deviceBlockSum = (CudaNdarray*) CudaNdarray_NewDims(2, shapeBlockSum);
for (int dataOffset = 0; dataOffset < dimGridY; dataOffset += maxGridY){ // Perform `maxGridY`*`maxGridZ` cumsums in parallel.
int localDimGridY = min(dimGridY - dataOffset, maxGridY); for (int offsetY = 0; offsetY < dimGridY; offsetY += maxGridY){
dim3 dimBlock(blockSize, 1, 1); int localDimGridY = min(dimGridY - offsetY, maxGridY);
dim3 dimGrid(dimGridX, localDimGridY, 1);
int sharedBytes = (2*blockSize) * sizeof(float); for (int offsetZ = 0; offsetZ < dimGridZ; offsetZ += maxGridZ){
int localDimGridZ = min(dimGridZ - offsetZ, maxGridZ);
dim3 dimGrid(dimGridX, localDimGridY, localDimGridZ);
dim3 dimBlock(dimBlockX, 1, 1); // One cumsum per block.
int sharedBytes = (2*dimBlockX) * sizeof(float);
k_blockCumSum_%(nodename)s<<<dimGrid, dimBlock, sharedBytes>>> k_blockCumSum_%(nodename)s<<<dimGrid, dimBlock, sharedBytes>>>
( (
CudaNdarray_DEV_DATA(input), CudaNdarray_DEV_DATA(input),
CudaNdarray_DEV_DATA(output), CudaNdarray_DEV_DATA(output),
numElements, nbElementsPerCumsum,
inputStrides, inputStrides,
outputStrides, outputStrides,
dataOffset, offsetY,
offsetZ,
CudaNdarray_DEV_DATA(deviceBlockSum) CudaNdarray_DEV_DATA(deviceBlockSum)
); );
if (dimGridX > 1) { if (dimGridX > 1) {
// Do a cumsum over the blockSum (recursive). // Do a cumsum over the blockSum (recursive).
if (cumSum_%(nodename)s(deviceBlockSum, deviceBlockSum, maxThreads, 0, maxGridY) == -1){ if (cumSum_%(nodename)s(deviceBlockSum, deviceBlockSum, 0, maxThreads, maxGridY, maxGridZ) == -1){
return -1; return -1;
} }
// Since there are more than one block (i.e. `dimGridX > 1`) // Since there are more than one block (i.e. `dimGridX > 1`)
// report partial cumsums of previous blocks to subsequents ones. // report partial cumsums of previous blocks to subsequents ones.
dim3 dimGrid(dimGridX, dimGridY, 1); dim3 dimGrid(dimGridX, localDimGridY, localDimGridZ);
dim3 dimBlock(blockSize, 1, 1); dim3 dimBlock(dimBlockX, 1, 1);
k_finalCumSum_%(nodename)s<<<dimGrid, dimBlock>>> k_finalCumSum_%(nodename)s<<<dimGrid, dimBlock>>>
( (
CudaNdarray_DEV_DATA(output), CudaNdarray_DEV_DATA(output),
CudaNdarray_DEV_DATA(deviceBlockSum), CudaNdarray_DEV_DATA(deviceBlockSum),
numElements, nbElementsPerCumsum,
outputStrides, outputStrides,
dataOffset offsetY,
offsetZ
); );
} }
// If shape[axis] is odd, the last element is compute manually // If shape[axis] is odd, the last element is compute manually
if (shape[axis] != numElements){ if (shape[axis] != nbElementsPerCumsum){
dim3 dimGrid(1, localDimGridY, 1); dim3 dimGrid(1, localDimGridY, localDimGridZ);
dim3 dimBlock(1, 1, 1); dim3 dimBlock(1, 1, 1);
k_cumadd_%(nodename)s<<<dimGrid, dimBlock>>> k_cumadd_%(nodename)s<<<dimGrid, dimBlock>>>
( (
...@@ -272,12 +334,14 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -272,12 +334,14 @@ class GpuCumsum(CumsumOp, GpuOp):
CudaNdarray_DEV_DATA(output), CudaNdarray_DEV_DATA(output),
inputStrides, inputStrides,
outputStrides, outputStrides,
dataOffset, offsetY,
offsetZ,
shape[axis]-2, shape[axis]-2,
shape[axis]-1 shape[axis]-1
); );
} }
} }
}
cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum)); cudaFree(CudaNdarray_DEV_DATA(deviceBlockSum));
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
...@@ -293,7 +357,8 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -293,7 +357,8 @@ class GpuCumsum(CumsumOp, GpuOp):
max_threads_dim0 = self.max_threads_dim0 max_threads_dim0 = self.max_threads_dim0
max_grid_size1 = self.max_grid_size1 max_grid_size1 = self.max_grid_size1
if max_threads_dim0 is None or max_grid_size1 is None: max_grid_size2 = self.max_grid_size2
if max_threads_dim0 is None or max_grid_size1 is None or max_grid_size2 is None:
raise NotImplementedError("GpuCumsum.c_code should not be called " raise NotImplementedError("GpuCumsum.c_code should not be called "
"directly. It should be called by " "directly. It should be called by "
"make_thunk() that add some information " "make_thunk() that add some information "
...@@ -322,7 +387,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -322,7 +387,7 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
{ // Namespace for kernel calls // { // Namespace for kernel calls //
if (cumSum_%(nodename)s(%(x)s, %(z)s, %(max_threads_dim0)s, %(axis)s, %(max_grid_size1)s) == -1){ if (cumSum_%(nodename)s(%(x)s, %(z)s, %(axis)s, %(max_threads_dim0)s, %(max_grid_size1)s, %(max_grid_size2)s) == -1){
%(fail)s; %(fail)s;
} }
......
...@@ -16,9 +16,8 @@ else: ...@@ -16,9 +16,8 @@ else:
from theano import tensor as T from theano import tensor as T
import numpy as np import numpy as np
import theano import theano
from theano import config
from theano.tensor.extra_ops import cumsum, CumsumOp from theano.tensor.extra_ops import cumsum, CumsumOp
import itertools
class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
mode = mode_with_gpu mode = mode_with_gpu
...@@ -45,68 +44,63 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -45,68 +44,63 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
def test_Strides1D(self): def test_Strides1D(self):
x = T.fvector('x') x = T.fvector('x')
# Stepped strides for axis in [0, None]:
f = theano.function([x], cumsum(x[::2]), mode=self.mode) a = np.random.random((42,)).astype("float32")
assert [n for n in f.maker.fgraph.toposort() cumsum_function = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
if isinstance(n.op, GpuCumsum)]
a = np.random.randint(10, size=(42,)).astype("float32")
assert np.allclose(np.cumsum(a[::2]), f(a))
# Alternative stepped strides slicings = [slice(None, None, None), # Normal strides
f = theano.function([x], cumsum(x), mode=self.mode) slice(None, None, 2), # Stepped strides
assert [n for n in f.maker.fgraph.toposort() slice(None, None, -1), # Negative strides
if isinstance(n.op, GpuCumsum)] ]
a = np.random.randint(10, size=(42,)).astype("float32")
assert np.allclose(np.cumsum(a[::2]), f(a[::2]))
# Negative strides # Cartesian product of all slicings to test.
f = theano.function([x], cumsum(x[::-1]), mode=self.mode) for slicing in itertools.product(slicings, repeat=x.ndim):
f = theano.function([x], cumsum(x[slicing], axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)] if isinstance(n.op, GpuCumsum)]
a = np.random.randint(10, size=(42,)).astype("float32") assert np.allclose(np.cumsum(a[slicing], axis=axis), f(a))
assert np.allclose(np.cumsum(a[::-1]), f(a)) assert np.allclose(np.cumsum(a[slicing], axis=axis), cumsum_function(a[slicing]))
def test_Strides2D(self): def test_Strides2D(self):
x = T.fmatrix('x') x = T.fmatrix('x')
for shape_axis, axis in zip([0, 1, 0], [0, 1, None]): for axis in [0, 1, None]:
a = np.random.random((42, 30)).astype("float32") a = np.random.random((42, 30)).astype("float32")
cumsum_function = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
# Stepped strides along axis=0 slicings = [slice(None, None, None), # Normal strides
f = theano.function([x], cumsum(x[::2], axis=axis), mode=self.mode) slice(None, None, 2), # Stepped strides
assert [n for n in f.maker.fgraph.toposort() slice(None, None, -1), # Negative strides
if isinstance(n.op, GpuCumsum)] ]
assert np.allclose(np.cumsum(a[::2], axis=axis), f(a))
# Stepped strides along axis=1 # Cartesian product of all slicings to test.
f = theano.function([x], cumsum(x[:, ::2], axis=axis), mode=self.mode) for slicing in itertools.product(slicings, repeat=x.ndim):
f = theano.function([x], cumsum(x[slicing], axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)] if isinstance(n.op, GpuCumsum)]
assert np.allclose(np.cumsum(a[:, ::2], axis=axis), f(a)) assert np.allclose(np.cumsum(a[slicing], axis=axis), f(a))
assert np.allclose(np.cumsum(a[slicing], axis=axis), cumsum_function(a[slicing]))
# Alternative stepped strides along axis=0 def test_Strides3D(self):
f = theano.function([x], cumsum(x), mode=self.mode) x = T.ftensor3('x')
assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)]
assert np.allclose(np.cumsum(a[::2]), f(a[::2]))
# Alternative stepped strides along axis=1 for axis in [0, 1, 2, None]:
f = theano.function([x], cumsum(x), mode=self.mode) a = np.random.random((42, 30, 25)).astype("float32")
assert [n for n in f.maker.fgraph.toposort() cumsum_function = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
if isinstance(n.op, GpuCumsum)]
assert np.allclose(np.cumsum(a[:, ::2]), f(a[:, ::2]))
# Negative strides along axis=0 slicings = [slice(None, None, None), # Normal strides
f = theano.function([x], cumsum(x[::-1], axis=axis), mode=self.mode) slice(None, None, 2), # Stepped strides
assert [n for n in f.maker.fgraph.toposort() slice(None, None, -1), # Negative strides
if isinstance(n.op, GpuCumsum)] ]
assert np.allclose(np.cumsum(a[::-1], axis=axis), f(a))
# Negative strides along axis=1 # Cartesian product of all slicings to test.
f = theano.function([x], cumsum(x[:, ::-1], axis=axis), mode=self.mode) for slicing in itertools.product(slicings, repeat=x.ndim):
f = theano.function([x], cumsum(x[slicing], axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)] if isinstance(n.op, GpuCumsum)]
assert np.allclose(np.cumsum(a[:, ::-1], axis=axis), f(a)) assert np.allclose(np.cumsum(a[slicing], axis=axis), f(a))
assert np.allclose(np.cumsum(a[slicing], axis=axis), cumsum_function(a[slicing]))
def test_GpuCumsum1D(self): def test_GpuCumsum1D(self):
block_max_size = self.max_threads_dim0 * 2 block_max_size = self.max_threads_dim0 * 2
...@@ -163,14 +157,63 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -163,14 +157,63 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
assert np.allclose(np.cumsum(a, axis=axis), f(a)) assert np.allclose(np.cumsum(a, axis=axis), f(a))
# Use recursive cumsum # Use recursive cumsum
a_shape = [5, 3] a_shape = [3, 3]
a_shape[shape_axis] = block_max_size*(block_max_size+1)+2 a_shape[shape_axis] = block_max_size*(block_max_size+1)+2
a = np.ones(a_shape, dtype="float32") a = np.random.random(a_shape).astype("float32")
a = np.sign(a-0.5).astype("float32") # Avoid floating point error
assert np.allclose(np.cumsum(a, axis=axis), f(a)) assert np.allclose(np.cumsum(a, axis=axis), f(a))
def test_GpuCumsum3D(self): def test_GpuCumsum3D(self):
# Should not use the GPU version. block_max_size = self.max_threads_dim0 * 2
x = T.ftensor3('x') x = T.ftensor3('x')
for shape_axis, axis in zip([0, 1, 2, 0], [0, 1, 2, None]):
f = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)]
# Extensive testing for the first 1025 sizes
a_shape = [5, 5, 5]
a_shape[shape_axis] = 1025
a = np.random.rand(*a_shape).astype("float32")
slices = [slice(None), slice(None), slice(None)]
for i in xrange(a.shape[shape_axis]):
slices[shape_axis] = slice(i)
fa = f(a[slices])
npa = np.cumsum(a[slices], axis=axis)
assert np.allclose(npa, fa)
# Use multiple GPU threadblocks (along accumulation axis)
a_shape = [2, 2, 2]
a_shape[shape_axis] = block_max_size+2
a = np.random.random(a_shape).astype("float32")
assert np.allclose(np.cumsum(a, axis=axis), f(a))
# Use multiple GPU gridblocks (not along accumulation axis)
a_shape = [5, 5, 5]
a_shape[(shape_axis+1) % 3] = self.max_grid_size1+1
a = np.random.random(a_shape).astype("float32")
if axis is None:
a = np.sign(a-0.5).astype("float32") # Avoid floating point error
assert np.allclose(np.cumsum(a, axis=axis), f(a))
a_shape = [5, 5, 5]
a_shape[(shape_axis+2) % 3] = self.max_grid_size1+1
a = np.random.random(a_shape).astype("float32")
if axis is None:
a = np.sign(a-0.5).astype("float32") # Avoid floating point error
assert np.allclose(np.cumsum(a, axis=axis), f(a))
# Use recursive cumsum (along accumulation axis)
a_shape = [3, 3, 3]
a_shape[shape_axis] = block_max_size*(block_max_size+1)+2
a = np.random.random(a_shape).astype("float32")
a = np.sign(a-0.5).astype("float32") # Avoid floating point error
assert np.allclose(np.cumsum(a, axis=axis), f(a))
def test_GpuCumsum4D(self):
# Should not use the GPU version.
x = T.ftensor4('x')
f = theano.function([x], cumsum(x, axis=1), mode=self.mode) f = theano.function([x], cumsum(x, axis=1), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, CumsumOp)] if isinstance(n.op, CumsumOp)]
...@@ -62,7 +62,7 @@ class TestCumsumOp(utt.InferShapeTester): ...@@ -62,7 +62,7 @@ class TestCumsumOp(utt.InferShapeTester):
utt.verify_grad(self.op, [a]) # Test axis=None utt.verify_grad(self.op, [a]) # Test axis=None
for axis in range(len(a.shape)): for axis in range(len(a.shape)):
utt.verify_grad(self.op_class(axis=axis), [a]) utt.verify_grad(self.op_class(axis=axis), [a], eps=4e-4)
class TestCumprodOp(utt.InferShapeTester): class TestCumprodOp(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论