提交 bf83d342，作者：Adam Becker

reduce shared memory consumption

上级 670c324b
......@@ -216,8 +216,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
int err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3,
grd, blk,
blk[0] * gpuarray_get_elsize(%(x)s->ga.typecode),
grd, blk, 0,
args);
if (err != GA_NO_ERROR) {
PyErr_SetString(
......
......@@ -258,7 +258,7 @@ KERNEL void k_topk_dense(
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
extern LOCAL_MEM radix_t smem[];
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
......@@ -267,7 +267,6 @@ KERNEL void k_topk_dense(
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
radix_t *wmem = (radix_t*)(smem) + warp_id * GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
......@@ -296,7 +295,6 @@ KERNEL void k_topk_dense(
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
smem[idx] = 0;
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
......@@ -304,13 +302,13 @@ KERNEL void k_topk_dense(
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
wmem[bin] += __popc(incr_bin_warp);
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=GA_WARP_SIZE; w<LDIM_0; w+=GA_WARP_SIZE)
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论