提交 a59d9a54 authored 作者: Frederic Bastien's avatar Frederic Bastien

simplified code.

上级 adb994dd
...@@ -1019,6 +1019,7 @@ class GpuSum(Op): ...@@ -1019,6 +1019,7 @@ class GpuSum(Op):
#TODO: This kernel is pretty inefficient in terms of reading, because if A is #TODO: This kernel is pretty inefficient in terms of reading, because if A is
# c_contiguous (typical case) then each warp is accessing non-contigous # c_contiguous (typical case) then each warp is accessing non-contigous
# memory (a segment of a column). # memory (a segment of a column).
reducebuf = self._k_reduce_buf('Z[blockIdx.x * sZ0]')
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_sum_10_%(nodename)s( static __global__ void kernel_reduce_sum_10_%(nodename)s(
const int d0, const int d0,
...@@ -1041,31 +1042,7 @@ class GpuSum(Op): ...@@ -1041,31 +1042,7 @@ class GpuSum(Op):
float Ai = A[i0 * sA0 + blockIdx.x * sA1]; float Ai = A[i0 * sA0 + blockIdx.x * sA1];
mysum += Ai; mysum += Ai;
} }
buf[threadNum] = mysum; %(reducebuf)s
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[blockIdx.x * sZ0] = buf[0];
}
}
}
} }
""" %locals() """ %locals()
if self.reduce_mask == (1,1,0): if self.reduce_mask == (1,1,0):
...@@ -1155,6 +1132,7 @@ class GpuSum(Op): ...@@ -1155,6 +1132,7 @@ class GpuSum(Op):
if self.reduce_mask == (0,0,1): if self.reduce_mask == (0,0,1):
# this kernel uses one block for each row, # this kernel uses one block for each row,
# threads per block for each element per row. # threads per block for each element per row.
reducebuf = self._k_reduce_buf('Z[i0 * sZ0 + i1 * sZ1]')
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_sum_001_%(nodename)s( static __global__ void kernel_reduce_sum_001_%(nodename)s(
const int d0, const int d0,
...@@ -1181,31 +1159,7 @@ class GpuSum(Op): ...@@ -1181,31 +1159,7 @@ class GpuSum(Op):
{ {
mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2]; mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2];
} }
buf[threadNum] = mysum; %(reducebuf)s
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[i0 * sZ0 + i1 * sZ1] = buf[0];
}
}
}
} }
} }
} }
...@@ -1234,6 +1188,7 @@ class GpuSum(Op): ...@@ -1234,6 +1188,7 @@ class GpuSum(Op):
} }
""" %locals() """ %locals()
if self.reduce_mask == (1,0,1,1): if self.reduce_mask == (1,0,1,1):
reducebuf = self._k_reduce_buf('Z[blockIdx.x*sZ0]')
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_sum_1011_%(nodename)s( static __global__ void kernel_reduce_sum_1011_%(nodename)s(
const unsigned int d0, const unsigned int d0,
...@@ -1264,31 +1219,7 @@ class GpuSum(Op): ...@@ -1264,31 +1219,7 @@ class GpuSum(Op):
} }
} }
} }
buf[threadNum] = mysum; %(reducebuf)s
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[blockIdx.x*sZ0] = buf[0];
}
}
}
} }
""" %locals() """ %locals()
return sio.getvalue() return sio.getvalue()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论