提交 1728cd8c authored 作者: Adam Becker's avatar Adam Becker

GPU speedup

上级 aa14269f
...@@ -48,6 +48,76 @@ POSSIBILITY OF SUCH DAMAGE. ...@@ -48,6 +48,76 @@ POSSIBILITY OF SUCH DAMAGE.
#endif #endif
// Returns the calling thread's lane index within its warp (0..31),
// read from the %laneid PTX special register.
__device__ __forceinline__ int lane_id() {
  int id;
  // FIX: inline PTX requires '%%' to emit a literal '%'; the previous
  // single-'%' form ("%laneid") did not name the special register.
  // The sibling lane_mask_* helpers below already use the '%%' form.
  asm("mov.s32 %0, %%laneid;" : "=r"(id));
  return id;
}
// Bitmask of lanes in this warp with an index strictly less than
// the calling thread's lane (reads the %lanemask_lt special register).
__device__ __forceinline__ unsigned lane_mask_lt() {
  unsigned m;
  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(m));
  return m;
}
// Bitmask of lanes in this warp with an index less than or equal to
// the calling thread's lane (reads the %lanemask_le special register).
__device__ __forceinline__ unsigned lane_mask_le() {
  unsigned m;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(m));
  return m;
}
// Bitmask of lanes in this warp with an index strictly greater than
// the calling thread's lane (reads the %lanemask_gt special register).
__device__ __forceinline__ unsigned lane_mask_gt() {
  unsigned m;
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(m));
  return m;
}
// Bitmask of lanes in this warp with an index greater than or equal to
// the calling thread's lane (reads the %lanemask_ge special register).
__device__ __forceinline__ unsigned lane_mask_ge() {
  unsigned m;
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(m));
  return m;
}
// Bitfield<T>: thin wrappers over the PTX bit-field extract (bfe) and
// bit-field insert (bfi) instructions. Only the unsigned 32- and 64-bit
// specializations below are defined; the primary template is empty so
// any other instantiation fails at compile time.
template <typename T>
struct Bitfield {};

template <>
struct Bitfield<unsigned int> {
  // Extract `len` bits of `val` starting at bit `pos` (zero-extended).
  static __device__ __forceinline__
  unsigned int get(unsigned int val, int pos, int len) {
    unsigned int out;
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(out) : "r"(val), "r"(pos), "r"(len));
    return out;
  }

  // Return `val` with `len` bits of `toInsert` written at bit `pos`.
  static __device__ __forceinline__
  unsigned int set(unsigned int val, unsigned int toInsert, int pos, int len) {
    unsigned int out;
    asm("bfi.b32 %0, %1, %2, %3, %4;" :
        "=r"(out) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
    return out;
  }
};

template <>
struct Bitfield<unsigned long long int> {
  // Extract `len` bits of `val` starting at bit `pos` (zero-extended).
  static __device__ __forceinline__
  unsigned long long int get(unsigned long long int val, int pos, int len) {
    unsigned long long int out;
    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(out) : "l"(val), "r"(pos), "r"(len));
    return out;
  }

  // Return `val` with `len` bits of `toInsert` written at bit `pos`.
  static __device__ __forceinline__
  unsigned long long int set(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
    unsigned long long int out;
    asm("bfi.b64 %0, %1, %2, %3, %4;" :
        "=l"(out) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
    return out;
  }
};
template <typename T> template <typename T>
struct RadixConfig { struct RadixConfig {
// Converts a type (maybe float) to an integer representation with the same // Converts a type (maybe float) to an integer representation with the same
...@@ -182,10 +252,6 @@ struct RadixConfig<ga_half> { ...@@ -182,10 +252,6 @@ struct RadixConfig<ga_half> {
#define INPUT_TYPE $inp_t #define INPUT_TYPE $inp_t
#define INDEX_TYPE $out_t #define INDEX_TYPE $out_t
#define bitsof(T) (sizeof(T)*8) #define bitsof(T) (sizeof(T)*8)
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<INPUT_TYPE>::RadixType #define radix_t RadixConfig<INPUT_TYPE>::RadixType
#define WRITE_VALUE $write_value #define WRITE_VALUE $write_value
#define WRITE_INDEX $write_index #define WRITE_INDEX $write_index
...@@ -194,28 +260,28 @@ struct RadixConfig<ga_half> { ...@@ -194,28 +260,28 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)" #error "RADIX_SIZE must be smaller than warp size (32)"
#endif #endif
static inline __device__ ga_size binary_cumsum( template <typename T>
int idx, int warp_id, int lane_id, ga_size* smem, bool value) { static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads // cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *no greater than* the current thread // whose id is *no greater than* the current thread
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3) // binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// cumsum within warp // cumsum within warp
ga_uint warp_bits = __ballot(value); ga_uint warp_bits = __ballot(value);
ga_size warp_sum = __popc(((2<<lane_id)-1) & warp_bits); T warp_sum = __popc(lane_mask_le() & warp_bits);
if (lane_id == 0) if (lane_id() == 0)
smem[warp_id] = __popc(warp_bits); smem[warp_id] = __popc(warp_bits);
local_barrier(); local_barrier();
// cumsum across warps in one thread // cumsum across warps in one thread
if (idx == 0) { if (idx == 0) {
int current = 0; T sum = smem[0];
for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) { for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
ga_size v = smem[i]; sum += smem[i];
smem[i] = smem[i]+current; smem[i] = sum;
current = current+v;
} }
} }
...@@ -229,28 +295,28 @@ static inline __device__ ga_size binary_cumsum( ...@@ -229,28 +295,28 @@ static inline __device__ ga_size binary_cumsum(
return warp_sum; return warp_sum;
} }
static inline __device__ ga_size binary_cumsum_exclusive( template <typename T>
int idx, int warp_id, int lane_id, ga_size* smem, bool value) { static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads // cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread // whose id is *less than* the current thread
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2) // binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// cumsum within warp // cumsum within warp
ga_uint warp_bits = __ballot(value); ga_uint warp_bits = __ballot(value);
ga_size warp_sum = __popc(((1<<lane_id)-1) & warp_bits); T warp_sum = __popc(lane_mask_lt() & warp_bits);
if (lane_id == 0) if (lane_id() == 0)
smem[warp_id] = __popc(warp_bits); smem[warp_id] = __popc(warp_bits);
local_barrier(); local_barrier();
// cumsum across warps in one thread // cumsum across warps in one thread
if (idx == 0) { if (idx == 0) {
int current = 0; T sum = smem[0];
for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) { for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
ga_size v = smem[i]; sum += smem[i];
smem[i] = smem[i]+current; smem[i] = sum;
current = current+v;
} }
} }
......
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024) // works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense( KERNEL void k_topk_dense(
$dims $dims
...@@ -15,17 +20,14 @@ KERNEL void k_topk_dense( ...@@ -15,17 +20,14 @@ KERNEL void k_topk_dense(
$src_strides $src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM} // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) { ga_size size) {
LOCAL_MEM ga_size smem[32 * RADIX_SIZE]; LOCAL_MEM ga_int smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup? LOCAL_MEM ga_int k2;
bool is_topk=true, is_topkth=true; const ga_uint idx = LID_0;
bool is_topk= (idx < size);
bool is_topkth = is_topk;
ga_size out_idx; ga_size out_idx;
const ga_ushort idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_ubyte warp_id = idx / GA_WARP_SIZE; const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on // 0. get the slice for thread block to work on
...@@ -41,90 +43,84 @@ KERNEL void k_topk_dense( ...@@ -41,90 +43,84 @@ KERNEL void k_topk_dense(
//} //}
// get input and its radix friendly form // get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0; const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0; radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);
// resolve negative k // resolve negative k
if (k<0) { x = ~x; k = -k; } if (k<0) { x = ~x; k = -k; }
if (idx==0) { if (idx==0)
k2 = k; k2 = k;
bins[RADIX_SIZE] = 1;
}
// 1. filter is_topk and is_topkth using radix select // 1. filter is_topk and is_topkth using radix select
#pragma unroll #pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) { for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1); const ga_int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*ga_int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp // count within warp
#pragma unroll #pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) { for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range; bool vote = (bin == digit) && is_topkth;
ga_uint incr_bin_warp = __ballot(incr_bin); ga_uint votes = __ballot(vote);
if (lane_id==0) if (lane_id()==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp); smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
} }
local_barrier(); local_barrier();
// sum counts across all warps // sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) { if (idx < RADIX_SIZE) {
ga_int sum = smem[idx];
#pragma unroll
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE) for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w]; sum += smem[idx + w];
smem[idx] = sum;
} }
local_barrier(); local_barrier();
// bins = k - cumsum(smem[:RADIX_SIZE]) // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) { if (idx == 0) {
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1]; ga_int sum = k2;
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
#pragma unroll #pragma unroll
for (int bin=RADIX_SIZE-1; bin; --bin) { for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1]; sum -= smem[bin];
if (bins[bin-1] > 0) smem[bin] = sum;
k2 = bins[bin-1]; k2 = (sum > 0) ? sum : k2;
} }
smem[RADIX_SIZE] = 1;
} }
local_barrier(); local_barrier();
if (is_topkth) {
// smem -> count is_topk &= (smem[digit+1] > 0);
// bins -> k2 - cumsum(count) is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
} }
local_barrier();
} }
// set k2 as number of exceeding values
if (idx==0) { if (idx==0) {
#pragma unroll #pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) { for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (bins[bin] <= 0) { if (smem[bin] <= 0) {
exceed = -bins[bin]; k2 = -smem[bin];
break; break;
} }
} }
} }
local_barrier(); local_barrier();
// 2. find the index of output array, if exists // 2. find the index of output array, if exists
if (exceed != 0) { if (k2 != 0) {
// top_kth value may not be unique, so we need to // top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values // perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive(idx, warp_id, lane_id, smem, is_topkth); out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
is_topk &= ((!is_topkth) || out_idx>=exceed); if ((out_idx >= k2) && is_topkth)
is_topk = false;
local_barrier();
} }
// perform binary cumsum on is_topk to determine the indices to put result // perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive(idx, warp_id, lane_id, smem, is_topk); out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);
local_barrier();
if (is_topk) { if (is_topk) {
#if WRITE_VALUE == 1 #if WRITE_VALUE == 1
......
// works when length on axis is larger than max allowed threads in block (1024) #define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is in [1025, 2^31-1]
KERNEL void k_topk_dense_large( KERNEL void k_topk_dense_large(
$dims $dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM} // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
...@@ -15,23 +20,22 @@ KERNEL void k_topk_dense_large( ...@@ -15,23 +20,22 @@ KERNEL void k_topk_dense_large(
$src_strides $src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM} // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) { ga_size size, ga_ushort inp_per_thread) {
LOCAL_MEM ga_size smem[32 * RADIX_SIZE]; LOCAL_MEM ga_int smem[32];
LOCAL_MEM radix_t known_bits, known_bits_mask; LOCAL_MEM radix_t known_bits;
ga_size out_idx; LOCAL_MEM ga_uint k2;
ga_size LOCAL_MEM write_base; int counts[RADIX_SIZE];
unsigned out_idx;
INPUT_TYPE xval; INPUT_TYPE xval;
radix_t x; radix_t x;
ga_int i;
bool in_range, is_topk; bool in_range, is_topk;
const ga_size idx = LID_0; const ga_uint idx = LID_0;
ga_size LOCAL_MEM k2; const ga_uint inp_idx = idx * inp_per_thread;
const ga_ushort warp_id = idx / GA_WARP_SIZE; const ga_int warp_id = idx / GA_WARP_SIZE;
const ga_ushort lane_id = idx % GA_WARP_SIZE;
// 0. get the slice for thread block to work on // 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz]) // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx; ga_uint gid = GID_0, gidx;
$set_slice $set_slice
//for(int i=1; i<NDIM; i++) { //for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i}; // gidx = gid % dims_$${i};
...@@ -42,55 +46,45 @@ KERNEL void k_topk_dense_large( ...@@ -42,55 +46,45 @@ KERNEL void k_topk_dense_large(
//} //}
src = ptr_add(src, idx*inp_per_thread*src_strides_0); src = ptr_add(src, idx*inp_per_thread*src_strides_0);
LOCAL_MEM radix_t inv_bits;
if (idx==0) { if (idx==0) {
known_bits = known_bits_mask = 0; known_bits = 0;
k2 = abs(k); k2 = (k>=0) ? k : -k;
inv_bits = (k>=0) ? 0 : (~0);
write_base = 0;
} }
const radix_t inv_bits = (k>=0) ? 0 : ~0;
if (k<0) { k = -k; } if (k<0) { k = -k; }
local_barrier(); local_barrier();
// 1. find bits of top-k-th value using radix select // 1. find bits of top-k-th value using radix select
#pragma unroll #pragma unroll
for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) { for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
/*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
if (lane_id == 0) {
#pragma unroll #pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) { for (int j=0; j<RADIX_SIZE; ++j)
smem[bin + warp_id*RADIX_SIZE] = 0; counts[j] = 0;
} if (warp_id == 0)
} smem[idx] = 0;
local_barrier(); local_barrier();
// count within warp
for (int j=0; j<inp_per_thread; ++j) { for (int j=0; j<inp_per_thread; ++j) {
in_range = (idx*inp_per_thread+j) < size; in_range = (inp_idx+j) < size;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0; xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval); x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1)); ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
// count within warp
#pragma unroll #pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) { for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = ( bool incr_bin = (
(bin == digit) && (bin == digit) &&
((x&known_bits_mask) == known_bits) && ((x >> (i+RADIX_BITS)) == known_bits) && in_range);
in_range); counts[bin] += __popc(__ballot(incr_bin));
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
} }
} }
local_barrier(); local_barrier();
// sum counts across all warps // sum counts across all warps
// TODO: test in-block parallel sum? if (lane_id() < RADIX_SIZE) {
if (idx < RADIX_SIZE) { atomicAdd(&smem[lane_id()], counts[lane_id()]);
for(int w=RADIX_SIZE;
w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
} }
local_barrier(); local_barrier();
...@@ -99,8 +93,7 @@ KERNEL void k_topk_dense_large( ...@@ -99,8 +93,7 @@ KERNEL void k_topk_dense_large(
#pragma unroll #pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) { for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] >= k2) { if (smem[bin] >= k2) {
known_bits |= (((radix_t)bin) << i); known_bits = (known_bits << RADIX_BITS) | bin;
known_bits_mask |= (((radix_t)(RADIX_SIZE-1)) << i);
break; break;
} else } else
k2 -= smem[bin]; k2 -= smem[bin];
...@@ -109,50 +102,55 @@ KERNEL void k_topk_dense_large( ...@@ -109,50 +102,55 @@ KERNEL void k_topk_dense_large(
local_barrier(); local_barrier();
} }
// now we use k2 for base index to write output
if (idx == 0)
k2 = 0;
local_barrier();
// 2. write values smaller than top-kth // 2. write values smaller than top-kth
for (i=0; i<inp_per_thread; ++i) { for (int i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size; in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0; xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval); x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range; is_topk = (x > known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id, lane_id, smem, is_topk); out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
if (is_topk) { if (is_topk) {
#if WRITE_VALUE == 1 #if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval; ptr_at(dstv, (out_idx+k2-1) * dstv_strides_0) = xval;
#endif #endif
#if WRITE_INDEX == 1 #if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i); ptr_at(dsti, (out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif #endif
} }
local_barrier(); local_barrier();
if (idx == blockDim.x - 1) if (idx == blockDim.x - 1)
write_base += out_idx; k2 += out_idx;
local_barrier(); local_barrier();
} }
// 3. write values equal to top-kth // 3. write values equal to top-kth
for (i=0; i<inp_per_thread; ++i) { for (int i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size; in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0; xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval); x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x == known_bits) && in_range; is_topk = (x == known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id, lane_id, smem, is_topk); out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
is_topk = ((out_idx+write_base) <= abs(k)) && is_topk; is_topk &= (out_idx+k2) <= k;
if (is_topk) { if (is_topk) {
#if WRITE_VALUE == 1 #if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval; ptr_at(dstv, (out_idx+k2-1) * dstv_strides_0) = xval;
#endif #endif
#if WRITE_INDEX == 1 #if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i); ptr_at(dsti, (out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(inp_idx+ i);
#endif #endif
} }
local_barrier(); local_barrier();
if (idx == blockDim.x - 1) if (idx == blockDim.x - 1)
write_base += out_idx; k2 += out_idx;
local_barrier(); local_barrier();
if(write_base >= abs(k)) if(k2 >= k)
break; break;
} }
} }
......
...@@ -20,6 +20,7 @@ except ImportError as e: ...@@ -20,6 +20,7 @@ except ImportError as e:
pass pass
# TODO GPU sort / argsort # TODO GPU sort / argsort
# TODO support when k >= 2^31
class GpuTopKOp(GpuKernelBase, TopKOp): class GpuTopKOp(GpuKernelBase, TopKOp):
...@@ -34,7 +35,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -34,7 +35,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
self, axis=-1, self, axis=-1,
idx_dtype='int64', idx_dtype='int64',
return_values=True, return_values=True,
return_indices=True): return_indices=True
):
GpuKernelBase.__init__(self) GpuKernelBase.__init__(self)
TopKOp.__init__( TopKOp.__init__(
self, axis=axis, self, axis=axis,
...@@ -206,10 +208,14 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -206,10 +208,14 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError, PyExc_ValueError,
"topk: k cannot larger than size on specified axis %(axis)d"); "topk: k cannot larger than size on specified axis %(axis)d");
%(fail)s; %(fail)s;
} else if (odims[%(axis)d] > 0x7fffffffu) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, k cannot larger or equal than 2^31");
%(fail)s;
} }
%(prep_output)s %(prep_output)s
// TODO better scheduling?
size_t blk[6]; size_t blk[6];
size_t *grd = blk+3; size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1; blk[0] = blk[1] = blk[2] = 1;
...@@ -227,7 +233,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -227,7 +233,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_distrides)s; %(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s); const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
// inputs per thread // inputs per thread
unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d/2)-1) / (%(MAX_TPB)d/2); unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d / 2)-1) / (%(MAX_TPB)d / 2);
void* args[] = { void* args[] = {
%(dims)s %(dims)s
%(params_dv)s %(params_dv)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论