提交 a3624d6f authored 作者: Adam Becker's avatar Adam Becker

CLUDA code -> pure CUDA

上级 62ce362d
......@@ -126,7 +126,7 @@ struct RadixConfig {
// We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent
typedef ga_uint RadixType;
typedef unsigned int RadixType;
static inline __device__ RadixType convert(T v) {
return (RadixType)v;
}
......@@ -137,17 +137,17 @@ struct RadixConfig {
};
// Radix-order mapping for float.
// Defect fixed: the span contained both the old (ga_float) and new (float)
// lines of the CLUDA->CUDA migration interleaved, i.e. duplicated
// conflicting declarations; this is the clean pure-CUDA version.
//
// convert() produces an unsigned key whose integer ordering matches the
// float ordering: non-negative values get the sign bit flipped, negative
// values get every bit flipped (so more-negative floats map to smaller
// keys). deconvert() is the exact inverse.
template <>
struct RadixConfig<float> {
  typedef unsigned int RadixType;

  static inline __device__ RadixType convert(float v) {
    RadixType x = __float_as_int(v);
    // negative -> flip all bits; non-negative -> flip only the sign bit
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
    return (x ^ mask);
  }

  static inline __device__ float deconvert(RadixType v) {
    // inverse of convert(): undo the conditional bit flip
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
    return __int_as_float(v ^ mask);
  }
};
// Radix-order mapping for double (64-bit analogue of the float mapping).
// Defect fixed: interleaved old (ga_double/ga_ulong) and new
// (double/unsigned long long) diff lines duplicated the declarations;
// reconstructed as the clean pure-CUDA version.
//
// convert(): -(x >> 63) is all-ones when the sign bit is set (flip all
// bits) and zero otherwise, OR'd with the sign bit (always flip it),
// giving an unsigned key ordered like the double. deconvert() inverts it.
template <>
struct RadixConfig<double> {
  typedef unsigned long long RadixType;

  static inline __device__ RadixType convert(double v) {
    RadixType x = __double_as_longlong(v);
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (x ^ mask);
  }

  static inline __device__ double deconvert(RadixType v) {
    // (v >> 63) - 1 is all-ones iff the (converted) sign bit is clear
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};
......@@ -172,52 +172,52 @@ struct RadixConfig<ga_double> {
// Radix-order mapping for signed char: bias by 128 so that the signed
// range [-128, 127] maps monotonically onto the unsigned range [0, 255].
// Defect fixed: interleaved old (ga_byte) and new (char) diff lines
// duplicated the declarations; reconstructed as the clean pure-CUDA
// version. NOTE(review): assumes `char` is signed on this target —
// true for nvcc's supported platforms, but worth confirming.
template <>
struct RadixConfig<char> {
  typedef unsigned int RadixType;

  static inline __device__ RadixType convert(char v) {
    return 128u + v;
  }

  static inline __device__ char deconvert(RadixType v) {
    return v - 128;
  }
};
// Radix-order mapping for short: XOR the sign bit (0x8000) so the signed
// range maps monotonically onto [0, 65535].
// Defect fixed: interleaved old (ga_short) and new (short) diff lines
// duplicated the declarations; reconstructed as the clean pure-CUDA
// version.
template <>
struct RadixConfig<short> {
  typedef unsigned int RadixType;

  static inline __device__ RadixType convert(short v) {
    assert(sizeof(short) == 2);  // mapping below hard-codes a 16-bit short
    return 32768u ^ v;
  }

  static inline __device__ short deconvert(RadixType v) {
    return v - 32768;
  }
};
// Radix-order mapping for int: bias by 2^31 so INT_MIN maps to 0 and
// INT_MAX to UINT_MAX (monotone signed -> unsigned mapping).
// Defect fixed: interleaved old (ga_int) and new (int) diff lines
// duplicated the declarations; reconstructed as the clean pure-CUDA
// version.
template <>
struct RadixConfig<int> {
  typedef unsigned int RadixType;

  static inline __device__ RadixType convert(int v) {
    assert(sizeof(int) == 4);  // mapping below hard-codes a 32-bit int
    return 2147483648u + v;
  }

  static inline __device__ int deconvert(RadixType v) {
    return v - 2147483648u;
  }
};
template <>
struct RadixConfig<ga_long> {
typedef ga_ulong RadixType;
struct RadixConfig<long long> {
typedef unsigned long long RadixType;
static inline __device__ RadixType convert(ga_long v) {
assert(sizeof(ga_long) == 8);
static inline __device__ RadixType convert(long long v) {
assert(sizeof(long long) == 8);
return 9223372036854775808ull + v;
}
......@@ -229,19 +229,19 @@ struct RadixConfig<ga_long> {
#define USE_HALF $use_half

#if USE_HALF == 1
// half values are carried here as their raw unsigned short bit pattern,
// so this specialization must be macro-guarded: enabling it
// unconditionally would also hijack genuine unsigned short data.
// Defect fixed: interleaved old (ga_half/ga_uint) and new diff lines
// duplicated the declarations; reconstructed as the clean pure-CUDA
// version. Same scheme as float, on 16 bits: negative -> flip all bits,
// non-negative -> flip only the sign bit (0x8000).
template <>
struct RadixConfig<unsigned short> {
  typedef unsigned int RadixType;

  static inline __device__ RadixType convert(unsigned short v) {
    // -(v >> 15) is all-ones iff the half's sign bit is set
    RadixType mask = -(((RadixType)v >> 15)) | 0x8000;
    return (v ^ mask);
  }

  static inline __device__ unsigned short deconvert(RadixType v) {
    RadixType mask = ((v >> 15) - 1) | 0x8000;
    return (unsigned short)(v ^ mask);
  }
};
#endif // USE_HALF
......@@ -274,7 +274,7 @@ static inline __device__ T binary_cumsum(
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// cumsum within warp
ga_uint warp_bits = __ballot(value);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_le() & warp_bits);
if (lane_id() == 0)
......@@ -285,7 +285,7 @@ static inline __device__ T binary_cumsum(
// cumsum across warps in one thread
if (idx == 0) {
T sum = smem[0];
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
for (int i = 1; i < blockDim.x / GA_WARP_SIZE; ++i) {
sum += smem[i];
smem[i] = sum;
}
......@@ -309,7 +309,7 @@ static inline __device__ T binary_cumsum_exclusive(
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// cumsum within warp
ga_uint warp_bits = __ballot(value);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_lt() & warp_bits);
if (lane_id() == 0)
......@@ -320,7 +320,7 @@ static inline __device__ T binary_cumsum_exclusive(
// cumsum across warps in one thread
if (idx == 0) {
T sum = smem[0];
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
for (int i = 1; i < blockDim.x / GA_WARP_SIZE; ++i) {
sum += smem[i];
smem[i] = sum;
}
......@@ -337,19 +337,19 @@ static inline __device__ T binary_cumsum_exclusive(
// Apply a raw (byte) offset to a typed pointer; used because array
// strides here are expressed in bytes, not elements.
// Defect fixed: the diff dump kept both the old (ga_ssize) and new
// (ssize_t) signatures; this is the clean pure-CUDA version.
template <typename T>
static __device__ inline T* ptr_add(T *ptr, ssize_t offset) {
  return (T*)((char*)ptr + offset);
}
// Get a mutable reference to an array element addressed by a raw (byte)
// offset (byte strides, as in ptr_add).
// Defect fixed: the diff dump kept both the old (ga_ssize) and new
// (ssize_t) signatures; this is the clean pure-CUDA version.
template <typename T>
static __device__ inline T& ptr_at(T *ptr, ssize_t offset) {
  return *((T*)((char*)ptr + offset));
}
// Read an array element addressed by a raw (byte) offset through the
// read-only data cache (__ldg); caller must guarantee the pointed-to
// data is not written concurrently.
// Defect fixed: the diff dump kept both the old (ga_ssize) and new
// (ssize_t) signatures; this is the clean pure-CUDA version.
template <typename T>
static __device__ inline T ptr_read_cached(T *ptr, ssize_t offset) {
  return __ldg(((T*)((char*)ptr + offset)));
}
......@@ -6,32 +6,32 @@
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
// ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
// ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ssize_t k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM ga_int smem[32 * RADIX_SIZE];
LOCAL_MEM ga_int k2;
const ga_uint idx = LID_0;
// ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
__shared__ int smem[32 * RADIX_SIZE];
__shared__ int k2;
const unsigned int idx = threadIdx.x;
bool is_topk= (idx < size);
bool is_topkth = is_topk;
ga_size out_idx;
size_t out_idx;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const unsigned char warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
ga_size gid = GID_0, gidx;
size_t gid = blockIdx.x, gidx;
$set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
......@@ -55,22 +55,22 @@ KERNEL void k_topk_dense(
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
const ga_int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*ga_int digit = (x>>i) & (RADIX_SIZE-1);*/
const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool vote = (bin == digit) && is_topkth;
ga_uint votes = __ballot(vote);
unsigned int votes = __ballot(vote);
if (lane_id()==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
}
local_barrier();
// sum counts across all warps
if (idx < RADIX_SIZE) {
ga_int sum = smem[idx];
int sum = smem[idx];
#pragma unroll
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
for(int w=RADIX_SIZE; w<blockDim.x*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
sum += smem[idx + w];
smem[idx] = sum;
}
......@@ -79,7 +79,7 @@ KERNEL void k_topk_dense(
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
ga_int sum = k2;
int sum = k2;
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
sum -= smem[bin];
......
......@@ -14,13 +14,13 @@ __device__ DataType find_pattern(DataType* smem,
CountType stride,
RadixType known_bits,
RadixType known_bits_mask) {
if (LID_0 < 32)
smem[LID_0] = 0;
if (threadIdx.x < 32)
smem[threadIdx.x] = 0;
local_barrier();
// All threads participate in the loop, in order to sync on the flag
for (CountType i = LID_0; i < (slice_size + (CountType)LDIM_0-1); i += LDIM_0) {
for (CountType i = threadIdx.x; i < (slice_size + (CountType)blockDim.x-1); i += blockDim.x) {
bool in_range = (i < slice_size);
DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;
......@@ -64,14 +64,14 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0;
if (LID_0 < RADIX_SIZE)
smem[LID_0] = 0;
if (threadIdx.x < RADIX_SIZE)
smem[threadIdx.x] = 0;
local_barrier();
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
for (CountType i = LID_0; i < slice_size; i += LDIM_0) {
for (CountType i = threadIdx.x; i < slice_size; i += blockDim.x) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool has_val = ((val & known_bits_mask) == known_bits);
......@@ -196,32 +196,32 @@ __device__ void radix_select(DataType* data,
KERNEL void KERNEL_NAME(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
// ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
// ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ssize_t k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM COUNT_TYPE smem[32];
// ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
__shared__ COUNT_TYPE smem[32];
INPUT_TYPE topkth_value;
const bool order = (k>0);
k = (order ? k : -k);
const ga_int idx = LID_0;
const ga_int warp_id = idx / GA_WARP_SIZE;
const int idx = threadIdx.x;
const int warp_id = idx / GA_WARP_SIZE;
// get the slice for thread block to work on
// size <- the axis to work on
// dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx;
unsigned int gid = blockIdx.x, gidx;
$set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
......@@ -250,10 +250,10 @@ KERNEL void KERNEL_NAME(
// All threads need to participate in the loop and the cumsum
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
COUNT_TYPE iter_bound = size + LDIM_0-1;
COUNT_TYPE iter_bound = size + blockDim.x-1;
INDEX_TYPE write_base = 0;
for (int i = idx; i < iter_bound; i += LDIM_0) {
for (int i = idx; i < iter_bound; i += blockDim.x) {
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk;
......@@ -264,7 +264,7 @@ KERNEL void KERNEL_NAME(
}
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1];
int carry = smem[blockDim.x / 32 - 1];
if (has_topk) {
COUNT_TYPE write_idx = write_base + index;
......@@ -281,13 +281,13 @@ KERNEL void KERNEL_NAME(
COUNT_TYPE topk_remaining = (k - write_base);
for (COUNT_TYPE i = idx; i < iter_bound; i += LDIM_0) {
for (COUNT_TYPE i = idx; i < iter_bound; i += blockDim.x) {
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk = in_range && (v == topkth_value);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1];
int carry = smem[blockDim.x / 32 - 1];
if (has_topk && index < topk_remaining) {
COUNT_TYPE write_idx = write_base + index;
......
......@@ -63,9 +63,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
# prepare "$" macros
if device_type == b'cuda':
ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ga_ssize src_strides_%d, ' % i for i in range(ndim))
dstv_strides_code = ''.join('ssize_t dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ssize_t dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ssize_t src_strides_%d, ' % i for i in range(ndim))
set_slice_code = '''
gidx = gid %% dims_%(i)d;
gid /= dims_%(i)d;
......@@ -80,7 +80,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(self.idx_dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
dims=''.join('size_t dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
dstv_strides=dstv_strides_code if self.return_values else '',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论