提交 670c324b authored 作者: Adam Becker's avatar Adam Becker

mixed changes

- fix thread count calculation bug - let write_value and write_index become C macro
上级 13095b4b
......@@ -82,8 +82,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim))
flags = Kernel.get_flags(node.inputs[0].dtype)
write_value = 'ptr_at(dstv, out_idx * dstv_strides_0) = xval' if self.return_values else ''
write_index = 'ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx' if self.return_indices else ''
subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(node.outputs[0].dtype),
......@@ -94,8 +92,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
dsti_strides=dsti_strides_code if self.return_indices else '',
src_strides=src_strides_code,
set_slice=set_slice_code,
write_value=write_value,
write_index=write_index,
write_value=int(self.return_values),
write_index=int(self.return_indices),
ndim=str(ndim))
# substitute "$" macros in kernel code
......@@ -192,14 +190,16 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
// TODO better scheduling?
size_t blk[6];
size_t *grd = blk+3;
blk[1] = blk[2] = 1;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size
blk[0] = ((dims[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
for(int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d)
grd[0] *= dims[i];
else
blk[0] = dims[i];
}
blk[0] = ((blk[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
%(def_dvstrides)s;
%(def_distrides)s;
......
......@@ -15,14 +15,14 @@ template <>
struct RadixConfig<float> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(float v) {
static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
static inline WITHIN_KERNEL float deconvert(RadixType v) {
static inline __device__ float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
......@@ -30,14 +30,14 @@ struct RadixConfig<float> {
};
template <>
struct RadixConfig<ga_uchar> {
struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(ga_uchar v) {
static inline __device__ RadixType convert(ga_ubyte v) {
return v;
}
static inline WITHIN_KERNEL ga_uchar deconvert(RadixType v) {
static inline __device__ ga_ubyte deconvert(RadixType v) {
return v;
}
};
......@@ -46,11 +46,11 @@ template <>
struct RadixConfig<char> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(char v) {
static inline __device__ RadixType convert(char v) {
return 128u + v;
}
static inline WITHIN_KERNEL char deconvert(RadixType v) {
static inline __device__ char deconvert(RadixType v) {
return v - 128;
}
};
......@@ -59,12 +59,12 @@ template <>
struct RadixConfig<ga_short> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(ga_short v) {
static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2);
return 32768u + v;
}
static inline WITHIN_KERNEL ga_short deconvert(RadixType v) {
static inline __device__ ga_short deconvert(RadixType v) {
return v - 32768;
}
};
......@@ -73,12 +73,12 @@ template <>
struct RadixConfig<int> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(int v) {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
return 2147483648u + v;
}
static inline WITHIN_KERNEL int deconvert(RadixType v) {
static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u;
}
};
......@@ -87,12 +87,12 @@ template <>
struct RadixConfig<long> {
typedef unsigned long long int RadixType;
static inline WITHIN_KERNEL RadixType convert(long v) {
static inline __device__ RadixType convert(long v) {
assert(sizeof(long) == 8);
return 9223372036854775808ull + v;
}
static inline WITHIN_KERNEL long deconvert(RadixType v) {
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
......@@ -101,13 +101,13 @@ template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline WITHIN_KERNEL RadixType convert(double v) {
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline WITHIN_KERNEL double deconvert(RadixType v) {
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
......@@ -118,7 +118,7 @@ template <>
struct RadixConfig<half> {
typedef ga_uint RadixType;
static inline WITHIN_KERNEL RadixType convert(half v) {
static inline __device__ RadixType convert(half v) {
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
RadixType x = __half_as_ushort(v);
RadixType mask = -((x >> 15)) | 0x8000;
......@@ -129,7 +129,7 @@ struct RadixConfig<half> {
#endif
}
static inline WITHIN_KERNEL half deconvert(RadixType v) {
static inline __device__ half deconvert(RadixType v) {
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
RadixType mask = ((v >> 15) - 1) | 0x8000;
return __ushort_as_half(v ^ mask);
......@@ -152,13 +152,15 @@ struct RadixConfig<half> {
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<INPUT_TYPE>::RadixType
#define WRITE_VALUE $write_value
#define WRITE_INDEX $write_index
#if RADIX_SIZE > GA_WARP_SIZE
#error "RADIX_SIZE must be smaller than warp size"
#if RADIX_SIZE > 32
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
template <typename T>
static inline WITHIN_KERNEL T binary_cumsum(
static inline __device__ T binary_cumsum(
int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *no greater than* the current thread
......@@ -194,7 +196,7 @@ static inline WITHIN_KERNEL T binary_cumsum(
}
template <typename T>
static inline WITHIN_KERNEL T binary_cumsum_exclusive(
static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread
......@@ -230,13 +232,13 @@ static inline WITHIN_KERNEL T binary_cumsum_exclusive(
// apply raw(byte) offset to pointer
template <typename T>
static WITHIN_KERNEL inline T* ptr_add(T *ptr, ga_ssize offset) {
static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
return (T*)((char*)ptr + offset);
}
// get array element using raw(byte) offset
template <typename T>
static WITHIN_KERNEL inline T& ptr_at(T *ptr, ga_ssize offset) {
static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
......@@ -366,9 +368,11 @@ KERNEL void k_topk_dense(
local_barrier();
if (is_topk) {
$write_value;
// ptr_at(dstv, out_idx * dstv_strides_0) = xval;
$write_index;
// ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论