Commit 7b13e2f4 — authored by Adam Becker

multiple improvements to gpu topk

- added xlarge kernel to handle array sizes >= 2^31
- ported original pytorch kernel
- various small fixes
Parent: 330dd345
......@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
// atomicAdd overload for signed 64-bit integers.  CUDA natively provides
// only the `unsigned long long` variant, so we reinterpret the bits;
// two's-complement addition is bit-identical for signed and unsigned.
// NOTE(review): `src` is taken by non-const reference so it can be
// reinterpret_cast as an lvalue; callers must pass an lvalue.  Newer CUDA
// toolkits may ship their own signed overload -- confirm no clash.
void __device__ atomicAdd(long long *dst, long long &src) {
atomicAdd(
reinterpret_cast<unsigned long long*>(dst),
reinterpret_cast<unsigned long long&>(src));
}
template <typename T>
static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
......@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
// Read an array element addressed by a raw (byte) offset, going through
// the read-only data cache via __ldg.  `offset` is in bytes, not in
// elements, which is why the pointer is stepped through `char*`.
// (The span previously contained two conflicting signature lines -- the
// old `ptr_read` name and the renamed `ptr_read_cached` -- a diff
// artifact that cannot compile; only the renamed definition is kept.)
template <typename T>
static __device__ inline T ptr_read_cached(T *ptr, ga_ssize offset) {
  return __ldg((T*)((char*)ptr + offset));
}
......@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
......@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
}
local_barrier();
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
ga_int sum = k2;
......@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
#endif
}
}
// Radix-select configuration: 2 bits per digit, i.e. 4 buckets per pass.
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
// Bit mask selecting radix digit `n` (counted from the least significant digit).
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
// Number of radix digits needed to cover all bits of type T.
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is in [1025, 2^31-1]
KERNEL void k_topk_dense_large(
#define COUNT_TYPE $count_t
#define KERNEL_NAME $kname
// works when array size along axis is within [1025, 2^63-1]
// Find the unique element v in the slice whose converted radix pattern
// satisfies (convert(v) & known_bits_mask) == known_bits.  Every thread
// of the block scans the slice; the (unique) finder publishes a flag in
// smem[0] and the value in smem[1], and all threads return that value.
// Preconditions: all threads of the block call this (it uses
// local_barrier()); `smem` has at least 2 writable elements (the first
// 32 are cleared here); `stride` is a byte stride (see ptr_read_cached).
template <typename DataType, typename RadixType, typename CountType>
__device__ DataType find_pattern(DataType* smem,
                                 DataType* data,
                                 CountType slice_size,
                                 CountType stride,
                                 RadixType known_bits,
                                 RadixType known_bits_mask) {
  if (LID_0 < 32)
    smem[LID_0] = 0;
  local_barrier();

  // All threads must execute the SAME number of loop iterations so the
  // barriers below stay aligned.  The previous bound,
  // `slice_size + LDIM_0 - 1`, is generally not a multiple of LDIM_0,
  // so threads could leave the loop after different trip counts and
  // then mismatch on local_barrier() (and return 0 instead of the
  // value).  Round the bound up to a multiple of the block size.
  CountType bound = ((slice_size + (CountType)LDIM_0 - 1)
                     / (CountType)LDIM_0) * (CountType)LDIM_0;
  for (CountType i = LID_0; i < bound; i += LDIM_0) {
    bool in_range = (i < slice_size);
    DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;

    if (in_range && ((RadixConfig<DataType>::convert(v) & known_bits_mask) == known_bits)) {
      // The matching element is unique, so there are no write conflicts.
      smem[0] = 1;
      smem[1] = v;  // can't use the value itself as the flag: it could be 0
    }
    local_barrier();

    DataType found = smem[0];
    DataType val = smem[1];
    local_barrier();

    // Every thread sees the same flag, so every thread returns together.
    if (found != 0)
      return val;
  }

  // Unreachable when the pattern is present in the slice (guaranteed by
  // the caller); keeps the compiler happy.
  return 0;
}
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radix_digit_pos`, but only
// those that pass the filter `((v & known_bits_mask) == known_bits)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RADIX_SIZE` elements.
// On return every thread's `counts[i]` holds the block-wide total for
// radix digit `i`.  All threads of the block must call this (it uses
// local_barrier()).  `stride` is a byte stride (see ptr_read_cached).
// NOTE(review): uses the legacy mask-less __ballot/__popc warp vote --
// fine for pre-Volta targets this was written for; would need
// __ballot_sync on sm_70+ -- confirm the supported arch range.
template <typename DataType, typename RadixType, typename CountType>
__device__ void count_radix_masked(CountType counts[RADIX_SIZE],
CountType* smem,
RadixType known_bits,
RadixType known_bits_mask,
int radix_digit_pos,
CountType slice_size,
CountType stride,
DataType* data) {
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0;
if (LID_0 < RADIX_SIZE)
smem[LID_0] = 0;
local_barrier();
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
for (CountType i = LID_0; i < slice_size; i += LDIM_0) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool hasVal = ((val & known_bits_mask) == known_bits);
RadixType digit_in_radix = Bitfield<RadixType>::get(val, radix_digit_pos, RADIX_BITS);
#pragma unroll
for (int j = 0; j < RADIX_SIZE; ++j) {
bool vote = hasVal && (digit_in_radix == j);
// __ballot gathers the vote of every lane; __popc counts the set
// bits, so each lane accumulates the whole warp's count for digit j.
counts[j] += __popc(__ballot(vote));
}
}
// Now, for each warp, sum values: after the loop every lane of a warp
// holds the same per-warp totals, so lane 0 alone publishes them.
if (lane_id() == 0) {
for (int i=0; i<RADIX_SIZE; ++i)
atomicAdd(&smem[i], counts[i]);
}
/*
// not sure why, but this just give wrong results
if (lane_id() < RADIX_SIZE)
atomicAdd(&smem[lane_id()], counts[lane_id()]);
*/
local_barrier();
// For each thread, read in the total counts
#pragma unroll
for (unsigned int i = 0; i < RADIX_SIZE; ++i)
counts[i] = smem[i];
local_barrier();
}
// Radix selection: finds the k-th largest (order == true) or k-th
// smallest (order == false) element of the slice, writing it to
// *top_kth.  Scans radix digits from most to least significant,
// narrowing the candidate pattern with count_radix_masked at each step.
// All threads of the block must call this; `smem` must hold at least
// max(32, RADIX_SIZE) CountType elements (it is also reused as a
// DataType buffer by find_pattern).
template <typename DataType, typename RadixType, typename CountType>
__device__ void radix_select(DataType* data,
CountType k,
bool order,
CountType slice_size,
CountType stride,
CountType* smem,
DataType* top_kth) {
// Per-thread buckets into which we accumulate digit counts in our
// radix
register CountType counts[RADIX_SIZE];
// We only consider elements x such that (x & known_bits_mask) == known_bits
// Initially, we consider all elements of the array, so the above
// statement is true regardless of input.
RadixType known_bits = 0, known_bits_mask = 0;
// We are looking for the top k_to_find-th element when iterating over
// digits; this count gets reduced by elimination when counting
// successive digits
// NOTE(review): abs() on a 64-bit CountType may resolve to the int
// overload in device code; the caller currently passes k > 0, but
// verify for the xlarge (long long) instantiation.
CountType k_to_find = abs(k);
// We start at the most significant digit in our radix, scanning
// through to the least significant digit
#pragma unroll
for (int digit_pos = bitsof(DataType) - RADIX_BITS;
digit_pos >= 0; digit_pos -= RADIX_BITS) {
// Count radix distribution for the current position and reduce
// across all threads
count_radix_masked<DataType, RadixType, CountType>(
counts, smem,
known_bits, known_bits_mask, digit_pos,
slice_size, stride, data);
// All threads participate in the comparisons below to know the
// final result
// CHECK_RADIX(i): inspect bucket i of the current digit.  Either the
// answer becomes unique (find it and return), or bucket i still
// contains the k-th element (fix this digit and move to the next),
// or bucket i is eliminated (subtract its count from k_to_find).
#define CHECK_RADIX(i) \\
int count = counts[i]; \\
/* All threads have the same value in counts here, so all */ \\
/* threads will return from the function. */ \\
if (count == 1 && k_to_find == 1) { \\
/* There is a unique answer. */ \\
known_bits = Bitfield<RadixType>::set( \\
known_bits, i, digit_pos, RADIX_BITS); \\
known_bits_mask = Bitfield<RadixType>::set( \\
known_bits_mask, RADIX_SIZE-1, digit_pos, RADIX_BITS); \\
/* The answer is now the unique element v such that: */ \\
/* (v & known_bits_mask) == known_bits */ \\
/* However, we do not yet know what the actual element is. We */ \\
/* need to perform a search through the data to find the */ \\
/* element that matches this pattern. */ \\
*top_kth = find_pattern<DataType, RadixType, CountType>( \\
(DataType*) smem, data, slice_size, \\
stride, known_bits, known_bits_mask); \\
return; \\
} \\
if (count >= k_to_find) { \\
known_bits = Bitfield<RadixType>::set(known_bits, i, digit_pos, RADIX_BITS); \\
known_bits_mask = Bitfield<RadixType>::set( \\
known_bits_mask, RADIX_SIZE-1, digit_pos, RADIX_BITS); \\
/* The top-Kth element v must now be one such that: */ \\
/* (v & known_bits_mask == known_bits) */ \\
/* but we haven't narrowed it down; we must check the next */ \\
/* least-significant digit */ \\
break; \\
} \\
k_to_find -= count
// Scan buckets high-to-low for top-k-largest, low-to-high for
// top-k-smallest.
if (order) {
#pragma unroll
for (int i=RADIX_SIZE - 1; i >= 0; --i) {
CHECK_RADIX(i);
}
} else {
#pragma unroll
for (int i=0; i < RADIX_SIZE; ++i) {
CHECK_RADIX(i);
}
}
#undef CHECK_RADIX
} // end digit_pos for
// There is no unique result, but there is a non-unique result
// matching `known_bits` exactly
*top_kth = RadixConfig<DataType>::deconvert(known_bits);
}
KERNEL void KERNEL_NAME(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
......@@ -19,139 +208,100 @@ KERNEL void k_topk_dense_large(
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) {
LOCAL_MEM ga_int smem[32];
LOCAL_MEM radix_t known_bits;
LOCAL_MEM ga_uint k2;
int counts[RADIX_SIZE];
unsigned out_idx;
INPUT_TYPE xval;
radix_t x;
bool in_range, is_topk;
const ga_uint idx = LID_0;
const ga_uint inp_idx = idx * inp_per_thread;
ga_size size) {
LOCAL_MEM COUNT_TYPE smem[32];
INPUT_TYPE topkth_value;
const bool order = (k>0);
k = (order ? k : -k);
const ga_int idx = LID_0;
const ga_int warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
// get the slice for thread block to work on
// size <- the axis to work on
// dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i});
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i});
// src = ptr_add(src, gidx*src_strides_$${i});
//}
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
if (idx==0) {
known_bits = 0;
k2 = (k>=0) ? k : -k;
}
const radix_t inv_bits = (k>=0) ? 0 : ~0;
if (k<0) { k = -k; }
local_barrier();
// 1. find bits of top-k-th value using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
#pragma unroll
for (int j=0; j<RADIX_SIZE; ++j)
counts[j] = 0;
if (warp_id == 0)
smem[idx] = 0;
local_barrier();
radix_select<INPUT_TYPE, radix_t, COUNT_TYPE>(
src, k, order, size, src_strides_0,
smem, &topkth_value);
// count within warp
for (int j=0; j<inp_per_thread; ++j) {
in_range = (inp_idx+j) < size;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (
(bin == digit) &&
((x >> (i+RADIX_BITS)) == known_bits) && in_range);
counts[bin] += __popc(__ballot(incr_bin));
}
}
local_barrier();
// Every value that is strictly less/greater than `pattern`
// (depending on sort dir) in sorted int format is in the top-K.
// The top-K value itself might not be unique.
//
// Since there are a variable number of elements that we see that
// are within the top-k, we don't know at what index to write out
// the resulting values.
// In order to get this, we perform an exclusive cumsum of
// `has_topk`. This will return the resulting index into which we
// need to write the result, if a thread has a result.
// sum counts across all warps
if (lane_id() < RADIX_SIZE) {
atomicAdd(&smem[lane_id()], counts[lane_id()]);
}
local_barrier();
// All threads need to participate in the loop and the prefix sum,
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
COUNT_TYPE iter_bound = size + LDIM_0-1;
INDEX_TYPE write_base = 0;
// update known bits
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] >= k2) {
known_bits = (known_bits << RADIX_BITS) | bin;
break;
} else
k2 -= smem[bin];
}
for (int i = idx; i < iter_bound; i += LDIM_0) {
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk;
if (order) {
has_topk = in_range && (v > topkth_value);
} else {
has_topk = in_range && (v < topkth_value);
}
local_barrier();
}
// now we use k2 for base index to write output
if (idx == 0)
k2 = 0;
local_barrier();
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1];
// 2. write values smaller than top-kth
for (int i=0; i<inp_per_thread; ++i) {
in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
if (is_topk) {
if (has_topk) {
COUNT_TYPE write_idx = write_base + index;
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+k2-1) * dstv_strides_0) = xval;
ptr_at(dstv, write_idx * dstv_strides_0) = v;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
ptr_at(dsti, write_idx * dsti_strides_0) = (INDEX_TYPE)i;
#endif
}
local_barrier();
if (idx == blockDim.x - 1)
k2 += out_idx;
local_barrier();
write_base += carry;
}
// 3. write values equal to top-kth
for (int i=0; i<inp_per_thread; ++i) {
in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x == known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
is_topk &= (out_idx+k2) <= k;
if (is_topk) {
COUNT_TYPE topk_remaining = (k - write_base);
for (COUNT_TYPE i = idx; i < iter_bound; i += LDIM_0) {
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk = in_range && (v == topkth_value);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1];
if (has_topk && index < topk_remaining) {
COUNT_TYPE write_idx = write_base + index;
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+k2-1) * dstv_strides_0) = xval;
ptr_at(dstv, write_idx * dstv_strides_0) = v;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(inp_idx+ i);
ptr_at(dsti, write_idx * dsti_strides_0) = (INDEX_TYPE)i;
#endif
}
local_barrier();
if (idx == blockDim.x - 1)
k2 += out_idx;
local_barrier();
if(k2 >= k)
if (carry >= topk_remaining)
break;
topk_remaining -= carry;
write_base += carry;
}
}
......@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def gpu_kernels(self, node, nodename):
# load kernel source
device_type = node.inputs[0].type.context.kind
knames = ['k_topk_dense', 'k_topk_dense_large']
kernel_ext = {b'cuda': '.cu', b'opencl': '.cl'}[device_type]
common_ext = {b'cuda': '.cuh', b'opencl': '.h'}[device_type]
kernel_src = {}
for kname in knames:
with open(os.path.join(
os.path.dirname(__file__), 'c_code', kname + kernel_ext
), 'r') as f:
kernel_src[kname] = f.read()
with open(os.path.join(
os.path.dirname(__file__), 'c_code', 'k_topk_common' + common_ext
), 'r') as f:
common_src = f.read()
# prepare "$" macros
if device_type == b'cuda':
......@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
elif device_type == b'opencl':
raise NotImplementedError()
# compile kernels
kernels = []
# setup parameters
param_types = [ga.SIZE] * (ndim - 1) # dims
for _ in range(int(self.return_values) + int(self.return_indices)):
for _ in range(self.return_values + self.return_indices):
param_types.append(ga.GpuArray) # dst*
param_types.extend([ga.SSIZE] * ndim) # dst*_strides
param_types.append(ga.SIZE) # k
param_types.append(ga.GpuArray) # src
param_types.extend([ga.SSIZE] * ndim) # src_strides
param_types.append(ga.SIZE) # size
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense']).substitute(**subs),
name='k_topk_dense',
params=param_types,
flags=flags,
objvar='k_topk_dense_' + nodename
))
param_types.append(np.uint16) # inp_per_thread
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense_large']).substitute(**subs),
name='k_topk_dense_large',
params=param_types,
flags=flags,
objvar='k_topk_dense_large_' + nodename
))
# load and compile kernels
with open(os.path.join(
os.path.dirname(__file__), 'c_code', 'k_topk_common' + common_ext
)) as f:
common_src = f.read()
kernels = []
def build_kernel(fname, kname, subs):
with open(os.path.join(
os.path.dirname(__file__), 'c_code', fname)) as f:
kernel_src = f.read()
ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs),
name=kname,
params=param_types,
flags=flags,
objvar=kname + nodename)
return ker
subs['count_t'] = 'int'
kernels.append(
build_kernel('k_topk_dense' + kernel_ext, 'k_topk_dense', subs))
subs['kname'] = 'k_topk_dense_large'
kernels.append(
build_kernel('k_topk_dense_large' + kernel_ext, 'k_topk_dense_large', subs))
subs['count_t'] = 'long long'
subs['kname'] = 'k_topk_dense_xlarge'
kernels.append(
build_kernel('k_topk_dense_large' + kernel_ext, 'k_topk_dense_xlarge', subs))
return kernels
def c_code(self, node, nodename, inps, outs, sub):
......@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError,
"topk: kth must not be zero");
%(fail)s;
} else if (dims[%(axis)d] < odims[%(axis)d]){
} else if (dims[%(axis)d] < odims[%(axis)d]) {
PyErr_SetString(
PyExc_ValueError,
"topk: kth cannot be larger than the size of specified axis %(axis)d");
%(fail)s;
} else if (dims[%(axis)d] >= (1u << 31)) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, array size of specified axis cannot larger or equal than 2^31");
%(fail)s;
}
%(prep_output)s
......@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
for(int i=0; i<%(ndim)d; ++i) {
for (int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d)
grd[0] *= dims[i];
else
......@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_dvstrides)s;
%(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
// inputs per thread
unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d / 2)-1) / (%(MAX_TPB)d / 2);
void* args[] = {
%(dims)s
%(params_dv)s
......@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
(void*)(%(x)s->ga.data),
%(sstrides)s,
(void*)(dims+%(axis)d),
(void*)(&ipt)
};
int err;
if (blk[0] > %(MAX_TPB)d) {
// LAUNCH_OUT_OF_RESOURCE if a 1024 sized block is used
blk[0] = %(MAX_TPB)d / 2;
if (dims[%(axis)d] > PY_SSIZE_T_MAX) {
PyErr_SetString(
PyExc_ValueError,
"topk: array size on specified axis is too large, should be less than PY_SSIZE_T_MAX.");
%(fail)s;
} else if (dims[%(axis)d] > (1u << 31)) {
blk[0] = %(MAX_TPB)d;
err = GpuKernel_call(
&k_topk_dense_xlarge%(nodename)s, 3,
grd, blk, 0, args);
} else if (blk[0] > %(MAX_TPB)d) {
blk[0] = %(MAX_TPB)d;
err = GpuKernel_call(
&k_topk_dense_large_%(nodename)s, 3,
&k_topk_dense_large%(nodename)s, 3,
grd, blk, 0, args);
} else {
err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3,
&k_topk_dense%(nodename)s, 3,
grd, blk, 0, args);
}
if (err != GA_NO_ERROR) {
......
......@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
assert -ndim <= axis < ndim
axis %= ndim
if k == 0:
raise ValueError('topk: k cannot be zero')
raise ValueError('topk: kth cannot be zero')
elif k > x.shape[axis]:
raise ValueError(
'topk: kth cannot be larger than the size of specified axis %d' % axis)
if abs(k) == 1:
# negative k means min instead of max
fn_max = [None, np.max, np.min][k]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论