提交 584520bc authored 作者: Marc-Alexandre Cote's avatar Marc-Alexandre Cote

Use the output's strides instead of assuming they are the same as the input's.

上级 e9c843bb
......@@ -78,7 +78,7 @@ class GpuCumsum(CumsumOp, GpuOp):
compute_map, no_recycling)
def c_code_cache_version(self):
return (4,)
return (5,)
def c_support_code_apply(self, node, nodename):
return """
......@@ -126,11 +126,15 @@ class GpuCumsum(CumsumOp, GpuOp):
}
__global__
void k_cumadd_%(nodename)s(float* input, float* output, dim3 dataStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY = (blockIdx.y + dataOffset) * dataStrides.y;
int idx_last = lastElementIdx*dataStrides.x + dataOffsetY;
int idx_beforelast = beforeLastElementIdx*dataStrides.x + dataOffsetY;
output[idx_last] = input[idx_last] + output[idx_beforelast];
void k_cumadd_%(nodename)s(float* input, float* output, dim3 inputStrides, dim3 outputStrides, int dataOffset, int beforeLastElementIdx, int lastElementIdx) {
int dataOffsetY_input = (blockIdx.y + dataOffset) * inputStrides.y;
int dataOffsetY_output = (blockIdx.y + dataOffset) * outputStrides.y;
int idx_last_input = lastElementIdx*inputStrides.x + dataOffsetY_input;
int idx_last_output = lastElementIdx*outputStrides.x + dataOffsetY_output;
int idx_beforelast = beforeLastElementIdx*outputStrides.x + dataOffsetY_output;
output[idx_last_output] = input[idx_last_input] + output[idx_beforelast];
}
__global__
......@@ -152,7 +156,7 @@ class GpuCumsum(CumsumOp, GpuOp):
}
__global__
void k_blockCumSum_%(nodename)s(float* input, float* output, int numElements, dim3 dataStrides, int dataOffset, float* blockSum) {
void k_blockCumSum_%(nodename)s(float* input, float* output, int numElements, dim3 inputStrides, dim3 outputStrides, int dataOffset, float* blockSum) {
// Regarding blockIdx and threadIdx, 'Cumsum' is always performed along the X axis.
// The Y axis will contain all the independent cumsums of the 2D case.
......@@ -166,7 +170,7 @@ class GpuCumsum(CumsumOp, GpuOp):
extern __shared__ float partialCumSum[];
// Load data in shared memory
k_fetchData_%(nodename)s(partialCumSum, input, globalThreadID, dataStrides, dataOffset);
k_fetchData_%(nodename)s(partialCumSum, input, globalThreadID, inputStrides, dataOffset);
// Use a dichotomy approach to compute the cumsum (i.e. balanced binary tree).
// The tree is swept from the leaves to the root and from the root to the leaves.
......@@ -175,7 +179,7 @@ class GpuCumsum(CumsumOp, GpuOp):
k_reversePhase_%(nodename)s(partialCumSum);
// Write the final output to global memory
k_pushData_%(nodename)s(partialCumSum, output, globalThreadID, dataStrides, dataOffset);
k_pushData_%(nodename)s(partialCumSum, output, globalThreadID, outputStrides, dataOffset);
if (blockSum != NULL){
if (threadIdx.x == blockDim.x - 1) {
......@@ -186,19 +190,23 @@ class GpuCumsum(CumsumOp, GpuOp):
int cumSum_%(nodename)s(CudaNdarray* input, CudaNdarray* output, int maxThreads, int axis, int maxGridY) {
int shape[2] = { 1, 1 };
dim3 dataStrides(0,0,0);
dim3 inputStrides(0,0,0);
dim3 outputStrides(0,0,0);
switch (CudaNdarray_NDIM(input))
{
case 1:
shape[0] = CudaNdarray_HOST_DIMS(input)[0];
dataStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
inputStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
outputStrides.x = CudaNdarray_HOST_STRIDES(output)[0];
break;
case 2:
shape[0] = CudaNdarray_HOST_DIMS(input)[0];
shape[1] = CudaNdarray_HOST_DIMS(input)[1];
dataStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
dataStrides.y = CudaNdarray_HOST_STRIDES(input)[1];
inputStrides.x = CudaNdarray_HOST_STRIDES(input)[0];
inputStrides.y = CudaNdarray_HOST_STRIDES(input)[1];
outputStrides.x = CudaNdarray_HOST_STRIDES(output)[0];
outputStrides.y = CudaNdarray_HOST_STRIDES(output)[1];
break;
default:
printf("Only 1D and 2D cumsum is implemented yet.\\n");
......@@ -211,9 +219,13 @@ class GpuCumsum(CumsumOp, GpuOp):
}
if (axis == 1) {
int tmp = dataStrides.x;
dataStrides.x = dataStrides.y;
dataStrides.y = tmp;
int tmp = inputStrides.x;
inputStrides.x = inputStrides.y;
inputStrides.y = tmp;
tmp = outputStrides.x;
outputStrides.x = outputStrides.y;
outputStrides.y = tmp;
}
int numElements = shape[axis] - (shape[axis] %% 2);
......@@ -235,7 +247,8 @@ class GpuCumsum(CumsumOp, GpuOp):
CudaNdarray_DEV_DATA(input),
CudaNdarray_DEV_DATA(output),
numElements,
dataStrides,
inputStrides,
outputStrides,
dataOffset,
CudaNdarray_DEV_DATA(deviceBlockSum)
);
......@@ -255,7 +268,7 @@ class GpuCumsum(CumsumOp, GpuOp):
CudaNdarray_DEV_DATA(output),
CudaNdarray_DEV_DATA(deviceBlockSum),
numElements,
dataStrides,
outputStrides,
dataOffset
);
}
......@@ -268,7 +281,8 @@ class GpuCumsum(CumsumOp, GpuOp):
(
CudaNdarray_DEV_DATA(input),
CudaNdarray_DEV_DATA(output),
dataStrides,
inputStrides,
outputStrides,
dataOffset,
shape[axis]-2,
shape[axis]-1
......@@ -305,11 +319,7 @@ class GpuCumsum(CumsumOp, GpuOp):
// If output is already allocated, check if its shape matches the input's one.
if (!needAllocation) {
for (int i= 0; i < CudaNdarray_NDIM(%(x)s); ++i) {
if (CudaNdarray_HOST_DIMS(%(x)s)[i] == CudaNdarray_HOST_DIMS(%(z)s)[i]) {
needAllocation = true;
}
if (CudaNdarray_HOST_STRIDES(%(x)s)[i] == CudaNdarray_HOST_STRIDES(%(z)s)[i]) {
if (CudaNdarray_HOST_DIMS(%(x)s)[i] != CudaNdarray_HOST_DIMS(%(z)s)[i]) {
needAllocation = true;
}
}
......@@ -318,11 +328,6 @@ class GpuCumsum(CumsumOp, GpuOp):
if (needAllocation){
Py_XDECREF(%(z)s);
%(z)s = (CudaNdarray*) CudaNdarray_NewDims(CudaNdarray_NDIM(%(x)s), shape);
// Copy strides information
for (int i= 0; i < CudaNdarray_NDIM(%(x)s); ++i) {
CudaNdarray_set_stride(%(z)s, i, CudaNdarray_HOST_STRIDES(%(x)s)[i]);
}
}
if (!%(z)s) {
......
......@@ -52,6 +52,13 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
a = np.random.randint(10, size=(42,)).astype("float32")
assert np.allclose(np.cumsum(a[::2]), f(a))
# Alternative stepped strides
f = theano.function([x], cumsum(x), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)]
a = np.random.randint(10, size=(42,)).astype("float32")
assert np.allclose(np.cumsum(a[::2]), f(a[::2]))
# Negative strides
f = theano.function([x], cumsum(x[::-1]), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort()
......@@ -59,6 +66,48 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
a = np.random.randint(10, size=(42,)).astype("float32")
assert np.allclose(np.cumsum(a[::-1]), f(a))
def test_Strides2D(self):
    """Check that GpuCumsum handles non-contiguous 2D inputs.

    Exercises stepped, alternatively-stepped (slice applied to the
    data rather than the graph input) and negative strides along both
    axes, for axis=0, axis=1 and axis=None (flattened) cumsums.
    """
    x = T.fmatrix('x')

    # shape_axis selects which dimension the stride manipulation targets;
    # axis is the cumsum axis (None means flatten).
    for shape_axis, axis in zip([0, 1, 0], [0, 1, None]):
        a = np.random.random((42, 30)).astype("float32")

        # Stepped strides along axis=0 (slicing inside the graph).
        f = theano.function([x], cumsum(x[::2], axis=axis), mode=self.mode)
        assert [n for n in f.maker.fgraph.toposort()
                if isinstance(n.op, GpuCumsum)]
        assert np.allclose(np.cumsum(a[::2], axis=axis), f(a))

        # Stepped strides along axis=1 (slicing inside the graph).
        f = theano.function([x], cumsum(x[:, ::2], axis=axis), mode=self.mode)
        assert [n for n in f.maker.fgraph.toposort()
                if isinstance(n.op, GpuCumsum)]
        assert np.allclose(np.cumsum(a[:, ::2], axis=axis), f(a))

        # Alternative stepped strides: the strided array is fed directly
        # to the compiled function, so the op sees non-contiguous input.
        # BUG FIX: the original built ``cumsum(x)`` without ``axis``, so
        # the axis=0 and axis=1 loop iterations silently re-tested the
        # flattened case instead of the intended axis.
        f = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
        assert [n for n in f.maker.fgraph.toposort()
                if isinstance(n.op, GpuCumsum)]
        # Stepped along axis=0 ...
        assert np.allclose(np.cumsum(a[::2], axis=axis), f(a[::2]))
        # ... and along axis=1 (same compiled function, different data).
        assert np.allclose(np.cumsum(a[:, ::2], axis=axis), f(a[:, ::2]))

        # Negative strides along axis=0.
        f = theano.function([x], cumsum(x[::-1], axis=axis), mode=self.mode)
        assert [n for n in f.maker.fgraph.toposort()
                if isinstance(n.op, GpuCumsum)]
        assert np.allclose(np.cumsum(a[::-1], axis=axis), f(a))

        # Negative strides along axis=1.
        f = theano.function([x], cumsum(x[:, ::-1], axis=axis), mode=self.mode)
        assert [n for n in f.maker.fgraph.toposort()
                if isinstance(n.op, GpuCumsum)]
        assert np.allclose(np.cumsum(a[:, ::-1], axis=axis), f(a))
def test_GpuCumsum1D(self):
block_max_size = self.max_threads_dim0 * 2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论