Commit 7b13e2f4 authored by Adam Becker

multiple improvements to gpu topk

- added xlarge kernel to handle array size >= 2^31
- ported original PyTorch kernel
- various small fixes
Parent 330dd345
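For context on the 2^31 boundary: the existing dense kernels count and index along the sorted axis with 32-bit `int`, which wraps negative once the axis size reaches 2^31, so the new xlarge kernel swaps in a 64-bit count type. A minimal standalone illustration of the overflow (not code from this commit):

```cuda
#include <cstdio>

int main() {
    long long size = 1LL << 31;  // smallest axis size routed to the xlarge kernel
    int wrapped = (int)size;     // a 32-bit counter wraps to INT_MIN here
    printf("axis size %lld seen through a 32-bit counter: %d\n", size, wrapped);
    return 0;
}
```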
@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
// CUDA only ships atomicAdd for unsigned long long; this overload adds
// a signed variant by reinterpreting the bits, which is safe because
// two's-complement addition is identical for signed and unsigned.
void __device__ atomicAdd(long long *dst, long long &src) {
  atomicAdd(
    reinterpret_cast<unsigned long long*>(dst),
    reinterpret_cast<unsigned long long&>(src));
}
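A hedged usage sketch for the overload above (kernel and variable names are assumed, not part of this commit): it lets device code accumulate signed 64-bit counters with the same `atomicAdd(dst, src)` spelling used for 32-bit ones.

```cuda
// assumed example: tally elements above zero into a signed 64-bit
// total, resolving to the overload defined above
__global__ void count_above(const float *x, long long n, long long *total) {
    long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    long long one = 1;  // an lvalue: the overload takes src by reference
    if (i < n && x[i] > 0.0f)
        atomicAdd(total, one);
}
```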
template <typename T>
static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
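The body of `binary_cumsum` is elided by the diff; warp-level binary cumsums are conventionally built from a ballot plus a population count. A hedged sketch of that idea (an assumption, not this commit's implementation):

```cuda
// each lane (0..31) learns how many lower-numbered lanes in its warp
// hold value == true, i.e. an exclusive cumulative sum of 0/1 flags
__device__ int warp_binary_cumsum(int lane, bool value) {
    unsigned ballot = __ballot_sync(0xffffffffu, value);
    unsigned below = (1u << lane) - 1u;  // mask of lanes strictly before this one
    return __popc(ballot & below);
}
```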
@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
// read array element using a raw (byte) offset through the read-only cache (__ldg)
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
static __device__ inline T ptr_read_cached(T *ptr, ga_ssize offset) {
return __ldg(((T*)((char*)ptr + offset)));
}
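Presumably the `_cached` rename flags that this load goes through `__ldg`, CUDA's read-only data cache intrinsic (sm_35+), which pays off when top-k re-reads source elements across many threads. A standalone usage sketch (names assumed; the real helper takes libgpuarray's `ga_ssize` byte offsets):

```cuda
#include <cstddef>

// assumed example: reduce a strided column with cached reads, applying
// a raw byte offset the same way ptr_read_cached does
__global__ void col_sum(const float *base, ptrdiff_t stride_bytes,
                        size_t n, float *out) {
    float acc = 0.0f;
    for (size_t i = threadIdx.x; i < n; i += blockDim.x)
        acc += __ldg((const float *)((const char *)base + i * stride_bytes));
    atomicAdd(out, acc);  // float atomicAdd is available on sm_20+
}
```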
@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
}
local_barrier();
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
ga_int sum = k2;
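The scan that follows is truncated here, but the comment above describes it: starting from `k2`, subtract per-bucket counts from the most significant radix digit downward until the bucket holding the k-th element is found. A hedged sketch of that selection (assumed names and plain `int` counts, not the kernel's exact code):

```cuda
// given counts[0..radix_size-1] = per-bucket tallies for the current
// digit, return the bucket containing the k2-th largest element and
// update k2 to the remaining rank inside it
__device__ int find_kth_bucket(const int *counts, int radix_size, int &k2) {
    int sum = k2;
    int bucket = radix_size - 1;      // largest values live in high buckets
    for (; bucket > 0; --bucket) {
        if (sum - counts[bucket] <= 0)
            break;                    // target falls inside this bucket
        sum -= counts[bucket];
    }
    k2 = sum;
    return bucket;
}
```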
@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
#endif
}
}
@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def gpu_kernels(self, node, nodename):
# load kernel source
device_type = node.inputs[0].type.context.kind
knames = ['k_topk_dense', 'k_topk_dense_large']
kernel_ext = {b'cuda': '.cu', b'opencl': '.cl'}[device_type]
common_ext = {b'cuda': '.cuh', b'opencl': '.h'}[device_type]
kernel_src = {}
for kname in knames:
with open(os.path.join(
os.path.dirname(__file__), 'c_code', kname + kernel_ext
), 'r') as f:
kernel_src[kname] = f.read()
with open(os.path.join(
os.path.dirname(__file__), 'c_code', 'k_topk_common' + common_ext
), 'r') as f:
common_src = f.read()
# prepare "$" macros
if device_type == b'cuda':
@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
elif device_type == b'opencl':
raise NotImplementedError()
# compile kernels
kernels = []
# setup parameters
param_types = [ga.SIZE] * (ndim - 1) # dims
for _ in range(int(self.return_values) + int(self.return_indices)):
for _ in range(self.return_values + self.return_indices):
param_types.append(ga.GpuArray) # dst*
param_types.extend([ga.SSIZE] * ndim) # dst*_strides
param_types.append(ga.SIZE) # k
param_types.append(ga.GpuArray) # src
param_types.extend([ga.SSIZE] * ndim) # src_strides
param_types.append(ga.SIZE) # size
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense']).substitute(**subs),
name='k_topk_dense',
params=param_types,
flags=flags,
objvar='k_topk_dense_' + nodename
))
param_types.append(np.uint16) # inp_per_thread
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense_large']).substitute(**subs),
name='k_topk_dense_large',
params=param_types,
flags=flags,
objvar='k_topk_dense_large_' + nodename
))
# load and compile kernels
with open(os.path.join(
os.path.dirname(__file__), 'c_code', 'k_topk_common' + common_ext
)) as f:
common_src = f.read()
kernels = []
def build_kernel(fname, kname, subs):
with open(os.path.join(
os.path.dirname(__file__), 'c_code', fname)) as f:
kernel_src = f.read()
ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs),
name=kname,
params=param_types,
flags=flags,
objvar=kname + nodename)
return ker
subs['count_t'] = 'int'
kernels.append(
build_kernel('k_topk_dense' + kernel_ext, 'k_topk_dense', subs))
subs['kname'] = 'k_topk_dense_large'
kernels.append(
build_kernel('k_topk_dense_large' + kernel_ext, 'k_topk_dense_large', subs))
subs['count_t'] = 'long long'
subs['kname'] = 'k_topk_dense_xlarge'
kernels.append(
build_kernel('k_topk_dense_large' + kernel_ext, 'k_topk_dense_xlarge', subs))
return kernels
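The refactor compiles the `k_topk_dense_large` source twice, once with `$count_t` as `int` and once as `long long` with `$kname` rewritten to `k_topk_dense_xlarge`, so one kernel body serves both size regimes. A hedged CUDA-side sketch of the same one-source/two-widths pattern (demo names, not the commit's kernels):

```cuda
// the macro stands in for the Python string.Template substitution:
// identical body, different counter width and kernel name
#define DEFINE_TOPK_DEMO(KNAME, COUNT_T)                             \
    __global__ void KNAME(const float *src, COUNT_T n, float *dst) { \
        COUNT_T i = (COUNT_T)blockIdx.x * blockDim.x + threadIdx.x;  \
        if (i < n) dst[i] = src[i];                                  \
    }

DEFINE_TOPK_DEMO(k_demo_large, int)         // axis size < 2^31
DEFINE_TOPK_DEMO(k_demo_xlarge, long long)  // axis size >= 2^31
```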
def c_code(self, node, nodename, inps, outs, sub):
@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError,
"topk: kth must not be zero");
%(fail)s;
} else if (dims[%(axis)d] < odims[%(axis)d]){
} else if (dims[%(axis)d] < odims[%(axis)d]) {
PyErr_SetString(
PyExc_ValueError,
"topk: kth cannot be larger than the size of specified axis %(axis)d");
%(fail)s;
} else if (dims[%(axis)d] >= (1u << 31)) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, array size of specified axis cannot larger or equal than 2^31");
%(fail)s;
}
%(prep_output)s
@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
for(int i=0; i<%(ndim)d; ++i) {
for (int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d)
grd[0] *= dims[i];
else
@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_dvstrides)s;
%(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
// inputs per thread
unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d / 2)-1) / (%(MAX_TPB)d / 2);
void* args[] = {
%(dims)s
%(params_dv)s
@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
(void*)(%(x)s->ga.data),
%(sstrides)s,
(void*)(dims+%(axis)d),
(void*)(&ipt)
};
int err;
if (blk[0] > %(MAX_TPB)d) {
// LAUNCH_OUT_OF_RESOURCE if a 1024 sized block is used
blk[0] = %(MAX_TPB)d / 2;
if (dims[%(axis)d] > PY_SSIZE_T_MAX) {
PyErr_SetString(
PyExc_ValueError,
"topk: array size on specified axis is too large, should be less than PY_SSIZE_T_MAX.");
%(fail)s;
} else if (dims[%(axis)d] >= (1u << 31)) {
blk[0] = %(MAX_TPB)d;
err = GpuKernel_call(
&k_topk_dense_xlarge%(nodename)s, 3,
grd, blk, 0, args);
} else if (blk[0] > %(MAX_TPB)d) {
blk[0] = %(MAX_TPB)d;
err = GpuKernel_call(
&k_topk_dense_large_%(nodename)s, 3,
&k_topk_dense_large%(nodename)s, 3,
grd, blk, 0, args);
} else {
err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3,
&k_topk_dense%(nodename)s, 3,
grd, blk, 0, args);
}
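In effect the generated C now makes a three-way choice from the axis size and block size; a condensed host-side sketch of that decision (simplified, omitting the PY_SSIZE_T_MAX error path; names assumed):

```cuda
#include <cstddef>
#include <cstdio>

// assumed demo mirroring the dispatch above
const char *pick_kernel(unsigned long long axis_size, size_t block, size_t max_tpb) {
    if (axis_size >= (1ull << 31))
        return "k_topk_dense_xlarge";  // 64-bit counters
    if (block > max_tpb)
        return "k_topk_dense_large";   // block clamped to max_tpb, 32-bit counters
    return "k_topk_dense";             // one block covers the axis
}

int main() {
    printf("%s\n", pick_kernel(1ull << 32, 4096, 1024));  // xlarge
    printf("%s\n", pick_kernel(1u << 20, 4096, 1024));    // large
    printf("%s\n", pick_kernel(512, 512, 1024));          // dense
    return 0;
}
```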
if (err != GA_NO_ERROR) {
......
@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
assert -ndim <= axis < ndim
axis %= ndim
if k == 0:
raise ValueError('topk: k cannot be zero')
raise ValueError('topk: kth cannot be zero')
elif k > x.shape[axis]:
raise ValueError(
'topk: kth cannot be larger than the size of specified axis %d' % axis)
if abs(k) == 1:
# negative k means min instead of max
fn_max = [None, np.max, np.min][k]
......