Commit 13095b4b authored by Adam Becker

fix segfault

Parent 95f6eda6
@@ -20,13 +20,16 @@ except ImportError as e:
     pass
-# TODO add support is slice size is larger than max allowed block size (1024)
+# TODO add support when slice size is larger than max allowed block size (1024)
 # TODO add runtime opt, if k==1, use max/min reduce
+# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
+# one result is needed
+# TODO add grad
 # TODO sort / argsort
 class GpuTopKOp(GpuKernelBase, TopKOp):
     '''
-    Implements TopKOp() on gpu
+    Implements TopKOp on gpu
     '''
     __props__ = TopKOp.__props__
@@ -79,11 +82,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
         set_slice_code = ''.join(
             set_slice_code % dict(i=j) for j in range(1, ndim))
         flags = Kernel.get_flags(node.inputs[0].dtype)
-        dst = ''
-        if self.return_values:
-            dst += 'INPUT_TYPE *dstv, '
-        if self.return_values:
-            dst += 'INDEX_TYPE *dsti, '
         write_value = 'ptr_at(dstv, out_idx * dstv_strides_0) = xval' if self.return_values else ''
         write_index = 'ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx' if self.return_indices else ''
         subs = dict(
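Note on the removed block: `dst` was assembled but, from what is shown here, never used, and both branches tested `self.return_values`, so the `dsti` fragment could never be controlled by `return_indices`. The actual parameter list is generated from the `subs` dict and `param_types` below, and the segfault this commit fixes plausibly came from the kernel signature and the host-side argument list being derived from inconsistent conditions. A hypothetical Python sketch, not the op's actual code, of why that matters:

    # Illustrative only: the kernel's declared parameters and the arguments the
    # host pushes at launch must be generated from the same flags; otherwise the
    # launch supplies a different number of values than the compiled kernel
    # expects, which can crash the process.
    def kernel_params(return_values, return_indices, ndim):
        params = []
        if return_values:
            params += ['dstv'] + ['dstv_strides_%d' % i for i in range(ndim)]
        if return_indices:
            params += ['dsti'] + ['dsti_strides_%d' % i for i in range(ndim)]
        params += ['src'] + ['src_strides_%d' % i for i in range(ndim)] + ['size']
        return params

    # the host side should build its argument list with identical logic and can
    # record len(params) up front (cf. the self.nargs bookkeeping added below)
    assert len(kernel_params(True, False, 2)) == 7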
@@ -92,8 +90,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
             dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
             dstv='INPUT_TYPE *dstv,' if self.return_values else '',
             dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
-            dstv_strides=dstv_strides_code,
-            dsti_strides=dsti_strides_code,
+            dstv_strides=dstv_strides_code if self.return_values else '',
+            dsti_strides=dsti_strides_code if self.return_indices else '',
             src_strides=src_strides_code,
             set_slice=set_slice_code,
             write_value=write_value,
@@ -112,6 +110,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
         param_types.append(ga.GpuArray)  # src
         param_types.extend([ga.SSIZE] * ndim)  # src_strides
         param_types.append(ga.SIZE)  # size
+        self.nargs = len(param_types)
         return [Kernel(
             code=kernel_src,
             name='k_topk_dense',
@@ -143,11 +142,12 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
         WARP_SIZE = 32
         ndim = node.inputs[0].ndim
+        nargs = self.nargs
         reordered_axes = list(range(ndim))
         axis = self.axis % ndim
         del(reordered_axes[axis])
         reordered_axes = [axis] + reordered_axes
-        dims = ', '.join('(void*)(dims+%d)' % i for i in reordered_axes[1:])
+        dims = ''.join('(void*)(dims+%d), ' % i for i in reordered_axes[1:])
         prep_output = ''
         if self.return_values:
             def_dvstrides = 'const ssize_t *dvstrides = PyGpuArray_STRIDES(%s)' % yv
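The `dims` fragment now carries its own trailing ", " per element instead of being joined with a separator. Assuming, as elsewhere in this code generator, that the fragment is concatenated in front of further argument fragments, the per-element comma keeps the pasted list valid even when some optional fragments are empty. A small illustrative sketch:

    # Illustrative assumption: the fragment is pasted before more arguments.
    reordered_axes = [1, 0, 2]
    dims = ''.join('(void*)(dims+%d), ' % i for i in reordered_axes[1:])
    call_args = '%s(void*)&k' % dims  # '(void*)&k' is a hypothetical tail argument
    # -> "(void*)(dims+0), (void*)(dims+2), (void*)&k"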
@@ -179,9 +179,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
 {
     const size_t *dims = PyGpuArray_DIMS(%(x)s);
     size_t odims[%(ndim)d];
-    for (int i=0; i<%(ndim)d; i++) {
+    for (int i=0; i<%(ndim)d; i++)
         odims[i] = dims[i];
-    }
     odims[%(axis)d] = *((%(k_dtype)s*)(PyArray_DATA(%(k)s)));
     if (odims[0] > %(MAX_TPB)d) {
         PyErr_SetString(
...
@@ -13,16 +13,16 @@ struct RadixConfig {};
 template <>
 struct RadixConfig<float> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(float v) {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(float v) {
         RadixType x = __float_as_int(v);
         RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
         return (x ^ mask);
     }
-    static inline __device__ float deconvert(RadixType v) {
+    static inline WITHIN_KERNEL float deconvert(RadixType v) {
         RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
         return __int_as_float(v ^ mask);
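For reference, convert() maps floats to unsigned keys whose integer order matches the floats' numeric order: flip every bit of a negative value, flip only the sign bit of a non-negative one. That is what lets the radix select further down work on any supported dtype. A CPU-side Python illustration of the same mapping, not part of this file:

    import struct

    def float_to_radix(v):
        # reinterpret the float's bits as an unsigned 32-bit int, like __float_as_int
        x = struct.unpack('<I', struct.pack('<f', v))[0]
        # negative floats: flip all bits; non-negative: flip only the sign bit
        mask = 0xffffffff if (x & 0x80000000) else 0x80000000
        return x ^ mask

    vals = [-2.0, -0.5, 0.0, 0.5, 2.0]
    keys = [float_to_radix(v) for v in vals]
    assert keys == sorted(keys)  # the mapping is order-preserving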
@@ -30,55 +30,55 @@ struct RadixConfig<float> {
 };
 template <>
-struct RadixConfig<unsigned char> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(unsigned char v) {
+struct RadixConfig<ga_uchar> {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(ga_uchar v) {
         return v;
     }
-    static inline __device__ unsigned char deconvert(RadixType v) {
+    static inline WITHIN_KERNEL ga_uchar deconvert(RadixType v) {
         return v;
     }
 };
 template <>
 struct RadixConfig<char> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(char v) {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(char v) {
         return 128u + v;
     }
-    static inline __device__ char deconvert(RadixType v) {
+    static inline WITHIN_KERNEL char deconvert(RadixType v) {
         return v - 128;
     }
 };
 template <>
-struct RadixConfig<short> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(short v) {
-        assert(sizeof(short) == 2);
+struct RadixConfig<ga_short> {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(ga_short v) {
+        assert(sizeof(ga_short) == 2);
         return 32768u + v;
     }
-    static inline __device__ short deconvert(RadixType v) {
+    static inline WITHIN_KERNEL ga_short deconvert(RadixType v) {
         return v - 32768;
     }
 };
 template <>
 struct RadixConfig<int> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(int v) {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(int v) {
         assert(sizeof(int) == 4);
         return 2147483648u + v;
     }
-    static inline __device__ int deconvert(RadixType v) {
+    static inline WITHIN_KERNEL int deconvert(RadixType v) {
         return v - 2147483648u;
     }
 };
@@ -87,12 +87,12 @@ template <>
 struct RadixConfig<long> {
     typedef unsigned long long int RadixType;
-    static inline __device__ RadixType convert(long v) {
+    static inline WITHIN_KERNEL RadixType convert(long v) {
         assert(sizeof(long) == 8);
         return 9223372036854775808ull + v;
     }
-    static inline __device__ long deconvert(RadixType v) {
+    static inline WITHIN_KERNEL long deconvert(RadixType v) {
         return v - 9223372036854775808ull;
     }
 };
@@ -101,13 +101,13 @@ template <>
 struct RadixConfig<double> {
     typedef unsigned long long int RadixType;
-    static inline __device__ RadixType convert(double v) {
+    static inline WITHIN_KERNEL RadixType convert(double v) {
         RadixType x = __double_as_longlong(v);
         RadixType mask = -((x >> 63)) | 0x8000000000000000;
         return (x ^ mask);
     }
-    static inline __device__ double deconvert(RadixType v) {
+    static inline WITHIN_KERNEL double deconvert(RadixType v) {
         RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
         return __longlong_as_double(v ^ mask);
     }
@@ -116,9 +116,9 @@ struct RadixConfig<double> {
 #ifdef USE_HALF
 template <>
 struct RadixConfig<half> {
-    typedef unsigned int RadixType;
-    static inline __device__ RadixType convert(half v) {
+    typedef ga_uint RadixType;
+    static inline WITHIN_KERNEL RadixType convert(half v) {
 #if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
         RadixType x = __half_as_ushort(v);
         RadixType mask = -((x >> 15)) | 0x8000;
@@ -129,7 +129,7 @@ struct RadixConfig<half> {
 #endif
     }
-    static inline __device__ half deconvert(RadixType v) {
+    static inline WITHIN_KERNEL half deconvert(RadixType v) {
 #if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
         RadixType mask = ((v >> 15) - 1) | 0x8000;
         return __ushort_as_half(v ^ mask);
@@ -142,7 +142,7 @@ struct RadixConfig<half> {
 #endif
 // $$inp_t should be replaced in c_code
-// we cannot use templated __global__ because gpuarray API does not support it yet
+// we cannot use templated kernel because gpuarray API does not support it
 #define NDIM $ndim
 #define INPUT_TYPE $inp_t
 #define INDEX_TYPE $out_t
@@ -153,33 +153,37 @@ struct RadixConfig<half> {
 #define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
 #define radix_t RadixConfig<INPUT_TYPE>::RadixType
-#if RADIX_SIZE > 32
-#error "RADIX_SIZE must be smaller than warp size (32)"
+#if RADIX_SIZE > GA_WARP_SIZE
+#error "RADIX_SIZE must be smaller than warp size"
 #endif
 template <typename T>
-static inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
-    // cumsum within 1D thread block, which adds up `value` of all threads whose id is *no greater than* the current thread
+static inline WITHIN_KERNEL T binary_cumsum(
+    int idx, int warp_id, int lane_id, T* smem, bool value) {
+    // cumsum within 1D thread block, which adds up `value` of all threads
+    // whose id is *no greater than* the current thread
+    // binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
     // cumsum within warp
-    unsigned int warp_bits = __ballot(value);
+    ga_uint warp_bits = __ballot(value);
     T warp_sum = __popc(((2<<lane_id)-1) & warp_bits);
     if (lane_id == 0)
         smem[warp_id] = __popc(warp_bits);
-    __syncthreads();
+    local_barrier();
     // cumsum across warps in one thread
     if (idx == 0) {
         int current = 0;
-        for (int i = 0; i < blockDim.x / 32; ++i) {
+        for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) {
             T v = smem[i];
             smem[i] = smem[i]+current;
             current = current+v;
         }
     }
-    __syncthreads();
+    local_barrier();
     // load the carry from the preceding warp
     if (warp_id >= 1) {
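A note on the __ballot/__popc idiom used for the warp-level step: every lane contributes one predicate bit, and a population count over the bits at or below a lane's position yields that lane's inclusive prefix sum. A CPU-side Python illustration, not part of the kernel source:

    def warp_inclusive_binary_cumsum(values):
        # pack one predicate bit per lane, like __ballot(value)
        warp_bits = 0
        for lane, v in enumerate(values):
            warp_bits |= (1 if v else 0) << lane
        # each lane counts the set bits at or below its own position,
        # mirroring __popc(((2 << lane_id) - 1) & warp_bits)
        return [bin(((2 << lane) - 1) & warp_bits).count('1')
                for lane in range(len(values))]

    assert warp_inclusive_binary_cumsum([1, 0, 1, 0, 1]) == [1, 1, 2, 2, 3]

The exclusive variant below differs only in masking with (1 << lane_id) - 1, i.e. strictly lower lanes.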
@@ -190,31 +194,32 @@ static inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* s
 }
 template <typename T>
-static inline __device__ T binary_cumsum_exclusive(
+static inline WITHIN_KERNEL T binary_cumsum_exclusive(
     int idx, int warp_id, int lane_id, T* smem, bool value) {
     // cumsum within 1D thread block, which adds up `value` of all threads
     // whose id is *less than* the current thread
+    // binary_cumsum(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
     // cumsum within warp
-    unsigned int warp_bits = __ballot(value);
+    ga_uint warp_bits = __ballot(value);
     T warp_sum = __popc(((1<<lane_id)-1) & warp_bits);
     if (lane_id == 0)
         smem[warp_id] = __popc(warp_bits);
-    __syncthreads();
+    local_barrier();
     // cumsum across warps in one thread
     if (idx == 0) {
         int current = 0;
-        for (int i = 0; i < blockDim.x / 32; ++i) {
+        for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) {
             T v = smem[i];
             smem[i] = smem[i]+current;
             current = current+v;
         }
     }
-    __syncthreads();
+    local_barrier();
     // load the carry from the preceding warp
     if (warp_id >= 1)
@@ -225,13 +230,13 @@ static inline __device__ T binary_cumsum_exclusive(
 // apply raw(byte) offset to pointer
 template <typename T>
-static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
+static WITHIN_KERNEL inline T* ptr_add(T *ptr, ga_ssize offset) {
     return (T*)((char*)ptr + offset);
 }
 // get array element using raw(byte) offset
 template <typename T>
-static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
+static WITHIN_KERNEL inline T& ptr_at(T *ptr, ga_ssize offset) {
     return *((T*)((char*)ptr + offset));
 }
@@ -250,21 +255,34 @@ KERNEL void k_topk_dense(
         INPUT_TYPE* src,
         $src_strides
         // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
-        size_t size) {
-    /*
-    extern __shared__ radix_t smem[];
-    ga_ssize __shared__ bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
-    bool is_topk = true;
-    bool is_topkth = true; // exactly k-th largest
+        ga_size size) {
+    extern LOCAL_MEM radix_t smem[];
+    ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
+    bool is_topk=true, is_topkth=true;
     radix_t out_idx;
-    const size_t idx = threadIdx.x;
-    size_t __shared__ k2, exceed;
-    const ga_uint warp_id = idx / 32;
-    const ga_uint lane_id = idx % 32;
-    radix_t *wmem = (radix_t*)(smem) + warp_id * 32;
+    const ga_size idx = LID_0;
+    ga_size LOCAL_MEM k2, exceed;
+    const ga_uint warp_id = idx / GA_WARP_SIZE;
+    const ga_uint lane_id = idx % GA_WARP_SIZE;
+    radix_t *wmem = (radix_t*)(smem) + warp_id * GA_WARP_SIZE;
     const bool in_range = (idx < size);
     is_topk &= in_range;
+    // 0. get the slice for thread block to work on
+    // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
+    ga_size gid = GID_0, gidx;
+    $set_slice
+    //for(int i=1; i<NDIM; i++) {
+    //    gidx = gid % dims_$${i};
+    //    gid /= dims_$${i};
+    //    dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
+    //    dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
+    //    src = ptr_add(src, gidx*src_strides_$${i});
+    //}
+    // get input and its radix friendly form
     const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
     radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
@@ -272,17 +290,6 @@ KERNEL void k_topk_dense(
     if (k<0) { x = ~x; k = -k; }
     if (idx==0) k2 = k;
-    // 0. get the slice for thread block to work on
-    size_t gid = blockIdx.x, gidx;
-    $set_slice
-    //for(int i=0; i<NDIM; i++) {
-    //    gidx = gid % dims_$${i};
-    //    gid /= dims_$${i};
-    //    dsti = ptr_add(dsti, gidx*dsti_strides_$${i+1};
-    //    dstv = ptr_add(dstv, gidx*dstv_strides_$${i+1};
-    //    src = ptr_add(src, gidx*src_strides_$${i+1});
-    //}
     // 1. filter is_topk and is_topkth using radix select
 #pragma unroll
@@ -293,18 +300,18 @@ KERNEL void k_topk_dense(
 #pragma unroll
     for (int bin=0; bin<RADIX_SIZE; ++bin) {
         bool incr_bin = (bin == digit) && is_topkth && in_range;
-        unsigned int incr_bin_warp = __ballot(incr_bin);
+        ga_uint incr_bin_warp = __ballot(incr_bin);
         if (lane_id==0)
             wmem[bin] += __popc(incr_bin_warp);
     }
-    __syncthreads();
+    local_barrier();
     // sum counts across all warps
     // TODO: test in-block parallel sum?
     if (idx < RADIX_SIZE) {
-        for(int w=32; w<blockDim.x; w+=32)
+        for(int w=GA_WARP_SIZE; w<LDIM_0; w+=GA_WARP_SIZE)
             smem[idx] += smem[idx + w];
     }
-    __syncthreads();
+    local_barrier();
     // calculate k minus cumsum(count)
     if (idx<RADIX_SIZE)
@@ -325,7 +332,7 @@ KERNEL void k_topk_dense(
                 exceed = min(exceed, bins[bin-1]);
             }
         }
-    __syncthreads();
+    local_barrier();
     // smem -> count
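For orientation, step 1 above is a textbook radix select: scan the keys digit by digit from the most significant radix digit, histogram the current digit among the still-alive candidates, and fix the digit of the k-th largest element before moving on. A sequential Python sketch of the same idea, illustrative only, operating on already radix-converted non-negative keys with names not taken from the kernel:

    def radix_select_kth_largest(keys, k, radix_bits=2):
        # narrow down the value of the k-th largest key one digit at a time
        nbits = max(keys).bit_length()
        nbits = ((nbits + radix_bits - 1) // radix_bits) * radix_bits
        prefix_mask = prefix_val = 0
        for pos in range(nbits - radix_bits, -1, -radix_bits):
            # histogram of the current digit among candidates matching the prefix
            counts = [0] * (1 << radix_bits)
            for x in keys:
                if (x & prefix_mask) == prefix_val:
                    counts[(x >> pos) & ((1 << radix_bits) - 1)] += 1
            # walk digits from largest to smallest until k is exhausted
            for d in range((1 << radix_bits) - 1, -1, -1):
                if k <= counts[d]:
                    prefix_mask |= ((1 << radix_bits) - 1) << pos
                    prefix_val |= d << pos
                    break
                k -= counts[d]
        return prefix_val

    assert radix_select_kth_largest([3, 9, 14, 1, 9, 7], k=3) == 9

The kernel performs the same narrowing in parallel: the per-warp histograms built with __ballot/__popc appear to play the role of `counts`, and `is_topkth` marks the candidates that still match the current prefix.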
@@ -356,7 +363,7 @@ KERNEL void k_topk_dense(
     // perform binary cumsum on is_topk to determine the indices to put result
     out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
-    __syncthreads();
+    local_barrier();
     if (is_topk) {
         $write_value;
@@ -364,5 +371,4 @@ KERNEL void k_topk_dense(
         $write_index;
         // ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
     }
-*/
 }