提交 670c324b authored 作者: Adam Becker's avatar Adam Becker

mixed changes

- fix thread count calculation bug - let write_value and write_index become C macro
上级 13095b4b
...@@ -82,8 +82,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -82,8 +82,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
set_slice_code = ''.join( set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim)) set_slice_code % dict(i=j) for j in range(1, ndim))
flags = Kernel.get_flags(node.inputs[0].dtype) flags = Kernel.get_flags(node.inputs[0].dtype)
write_value = 'ptr_at(dstv, out_idx * dstv_strides_0) = xval' if self.return_values else ''
write_index = 'ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx' if self.return_indices else ''
subs = dict( subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype), inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(node.outputs[0].dtype), out_t=ga.dtype_to_ctype(node.outputs[0].dtype),
...@@ -94,8 +92,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -94,8 +92,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
dsti_strides=dsti_strides_code if self.return_indices else '', dsti_strides=dsti_strides_code if self.return_indices else '',
src_strides=src_strides_code, src_strides=src_strides_code,
set_slice=set_slice_code, set_slice=set_slice_code,
write_value=write_value, write_value=int(self.return_values),
write_index=write_index, write_index=int(self.return_indices),
ndim=str(ndim)) ndim=str(ndim))
# substitute "$" macros in kernel code # substitute "$" macros in kernel code
...@@ -192,14 +190,16 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -192,14 +190,16 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
// TODO better scheduling? // TODO better scheduling?
size_t blk[6]; size_t blk[6];
size_t *grd = blk+3; size_t *grd = blk+3;
blk[1] = blk[2] = 1; blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1; grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size // round up to multiples of warp size
blk[0] = ((dims[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
for(int i=0; i<%(ndim)d; ++i) { for(int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d) if (i!=%(axis)d)
grd[0] *= dims[i]; grd[0] *= dims[i];
else
blk[0] = dims[i];
} }
blk[0] = ((blk[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
%(def_dvstrides)s; %(def_dvstrides)s;
%(def_distrides)s; %(def_distrides)s;
......
...@@ -15,14 +15,14 @@ template <> ...@@ -15,14 +15,14 @@ template <>
struct RadixConfig<float> { struct RadixConfig<float> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(float v) { static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v); RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask); return (x ^ mask);
} }
static inline WITHIN_KERNEL float deconvert(RadixType v) { static inline __device__ float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask); return __int_as_float(v ^ mask);
...@@ -30,14 +30,14 @@ struct RadixConfig<float> { ...@@ -30,14 +30,14 @@ struct RadixConfig<float> {
}; };
template <> template <>
struct RadixConfig<ga_uchar> { struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(ga_uchar v) { static inline __device__ RadixType convert(ga_ubyte v) {
return v; return v;
} }
static inline WITHIN_KERNEL ga_uchar deconvert(RadixType v) { static inline __device__ ga_ubyte deconvert(RadixType v) {
return v; return v;
} }
}; };
...@@ -46,11 +46,11 @@ template <> ...@@ -46,11 +46,11 @@ template <>
struct RadixConfig<char> { struct RadixConfig<char> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(char v) { static inline __device__ RadixType convert(char v) {
return 128u + v; return 128u + v;
} }
static inline WITHIN_KERNEL char deconvert(RadixType v) { static inline __device__ char deconvert(RadixType v) {
return v - 128; return v - 128;
} }
}; };
...@@ -59,12 +59,12 @@ template <> ...@@ -59,12 +59,12 @@ template <>
struct RadixConfig<ga_short> { struct RadixConfig<ga_short> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(ga_short v) { static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2); assert(sizeof(ga_short) == 2);
return 32768u + v; return 32768u + v;
} }
static inline WITHIN_KERNEL ga_short deconvert(RadixType v) { static inline __device__ ga_short deconvert(RadixType v) {
return v - 32768; return v - 32768;
} }
}; };
...@@ -73,12 +73,12 @@ template <> ...@@ -73,12 +73,12 @@ template <>
struct RadixConfig<int> { struct RadixConfig<int> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(int v) { static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4); assert(sizeof(int) == 4);
return 2147483648u + v; return 2147483648u + v;
} }
static inline WITHIN_KERNEL int deconvert(RadixType v) { static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u; return v - 2147483648u;
} }
}; };
...@@ -87,12 +87,12 @@ template <> ...@@ -87,12 +87,12 @@ template <>
struct RadixConfig<long> { struct RadixConfig<long> {
typedef unsigned long long int RadixType; typedef unsigned long long int RadixType;
static inline WITHIN_KERNEL RadixType convert(long v) { static inline __device__ RadixType convert(long v) {
assert(sizeof(long) == 8); assert(sizeof(long) == 8);
return 9223372036854775808ull + v; return 9223372036854775808ull + v;
} }
static inline WITHIN_KERNEL long deconvert(RadixType v) { static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull; return v - 9223372036854775808ull;
} }
}; };
...@@ -101,13 +101,13 @@ template <> ...@@ -101,13 +101,13 @@ template <>
struct RadixConfig<double> { struct RadixConfig<double> {
typedef unsigned long long int RadixType; typedef unsigned long long int RadixType;
static inline WITHIN_KERNEL RadixType convert(double v) { static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v); RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000; RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask); return (x ^ mask);
} }
static inline WITHIN_KERNEL double deconvert(RadixType v) { static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask); return __longlong_as_double(v ^ mask);
} }
...@@ -118,7 +118,7 @@ template <> ...@@ -118,7 +118,7 @@ template <>
struct RadixConfig<half> { struct RadixConfig<half> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(half v) { static inline __device__ RadixType convert(half v) {
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000 #if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
RadixType x = __half_as_ushort(v); RadixType x = __half_as_ushort(v);
RadixType mask = -((x >> 15)) | 0x8000; RadixType mask = -((x >> 15)) | 0x8000;
...@@ -129,7 +129,7 @@ struct RadixConfig<half> { ...@@ -129,7 +129,7 @@ struct RadixConfig<half> {
#endif #endif
} }
static inline WITHIN_KERNEL half deconvert(RadixType v) { static inline __device__ half deconvert(RadixType v) {
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000 #if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
RadixType mask = ((v >> 15) - 1) | 0x8000; RadixType mask = ((v >> 15) - 1) | 0x8000;
return __ushort_as_half(v ^ mask); return __ushort_as_half(v ^ mask);
...@@ -152,13 +152,15 @@ struct RadixConfig<half> { ...@@ -152,13 +152,15 @@ struct RadixConfig<half> {
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS)) #define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS) #define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<INPUT_TYPE>::RadixType #define radix_t RadixConfig<INPUT_TYPE>::RadixType
#define WRITE_VALUE $write_value
#define WRITE_INDEX $write_index
#if RADIX_SIZE > GA_WARP_SIZE #if RADIX_SIZE > 32
#error "RADIX_SIZE must be smaller than warp size" #error "RADIX_SIZE must be smaller than warp size (32)"
#endif #endif
template <typename T> template <typename T>
static inline WITHIN_KERNEL T binary_cumsum( static inline __device__ T binary_cumsum(
int idx, int warp_id, int lane_id, T* smem, bool value) { int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads // cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *no greater than* the current thread // whose id is *no greater than* the current thread
...@@ -194,7 +196,7 @@ static inline WITHIN_KERNEL T binary_cumsum( ...@@ -194,7 +196,7 @@ static inline WITHIN_KERNEL T binary_cumsum(
} }
template <typename T> template <typename T>
static inline WITHIN_KERNEL T binary_cumsum_exclusive( static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, int lane_id, T* smem, bool value) { int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads // cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread // whose id is *less than* the current thread
...@@ -230,13 +232,13 @@ static inline WITHIN_KERNEL T binary_cumsum_exclusive( ...@@ -230,13 +232,13 @@ static inline WITHIN_KERNEL T binary_cumsum_exclusive(
// apply raw(byte) offset to pointer // apply raw(byte) offset to pointer
template <typename T> template <typename T>
static WITHIN_KERNEL inline T* ptr_add(T *ptr, ga_ssize offset) { static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
return (T*)((char*)ptr + offset); return (T*)((char*)ptr + offset);
} }
// get array element using raw(byte) offset // get array element using raw(byte) offset
template <typename T> template <typename T>
static WITHIN_KERNEL inline T& ptr_at(T *ptr, ga_ssize offset) { static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset)); return *((T*)((char*)ptr + offset));
} }
...@@ -366,9 +368,11 @@ KERNEL void k_topk_dense( ...@@ -366,9 +368,11 @@ KERNEL void k_topk_dense(
local_barrier(); local_barrier();
if (is_topk) { if (is_topk) {
$write_value; #if WRITE_VALUE == 1
// ptr_at(dstv, out_idx * dstv_strides_0) = xval; ptr_at(dstv, out_idx * dstv_strides_0) = xval;
$write_index; #endif
// ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx; #if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
} }
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论