提交 95f6eda6 authored 作者: Adam Becker's avatar Adam Becker

mixed changes

- add multi-dimensional support for top_k - use a unified TopKOp, which can implement topk, argtopk, or both at once
上级 ece8b25b
差异被折叠。
...@@ -113,6 +113,7 @@ struct RadixConfig<double> { ...@@ -113,6 +113,7 @@ struct RadixConfig<double> {
} }
}; };
#ifdef USE_HALF
template <> template <>
struct RadixConfig<half> { struct RadixConfig<half> {
typedef unsigned int RadixType; typedef unsigned int RadixType;
...@@ -138,23 +139,29 @@ struct RadixConfig<half> { ...@@ -138,23 +139,29 @@ struct RadixConfig<half> {
#endif #endif
} }
}; };
#endif
// $$inp_t should be replaced in c_code
// we cannot use templated __global__ because gpuarray API does not support it yet
#define NDIM $ndim
#define INPUT_TYPE $inp_t
#define INDEX_TYPE $out_t
#define bitsof(T) (sizeof(T)*8) #define bitsof(T) (sizeof(T)*8)
#define RADIX_BITS 2 #define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS) #define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS)) #define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS) #define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<T>::RadixType #define radix_t RadixConfig<INPUT_TYPE>::RadixType
#if RADIX_SIZE > 32 #if RADIX_SIZE > 32
#error "RADIX_SIZE must be smaller than warp size (32)" #error "RADIX_SIZE must be smaller than warp size (32)"
#endif #endif
template <typename T> template <typename T>
inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) { static inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads whose id is *no greater than* the current thread // cumsum within 1D thread block, which adds up `value` of all threads whose id is *no greater than* the current thread
// cumsum within warp // cumsum within warp
unsigned int warp_bits = __ballot(in); unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((2<<lane_id)-1) & warp_bits); T warp_sum = __popc(((2<<lane_id)-1) & warp_bits);
if (lane_id == 0) if (lane_id == 0)
...@@ -175,20 +182,21 @@ inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bo ...@@ -175,20 +182,21 @@ inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bo
__syncthreads(); __syncthreads();
// load the carry from the preceding warp // load the carry from the preceding warp
if (warp >= 1) { if (warp_id >= 1) {
warp_sum = warp_sum+smem[warp - 1]; warp_sum = warp_sum+smem[warp_id - 1];
} }
return warp_sum; return warp_sum;
} }
template <typename T> template <typename T>
inline __device__ T binary_cumsum_exclusive( static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, int lane_id, T* smem, bool value) { int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads // cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread // whose id is *less than* the current thread
// cumsum within warp // cumsum within warp
unsigned int warp_bits = __ballot(in); unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((1<<lane_id)-1) & warp_bits); T warp_sum = __popc(((1<<lane_id)-1) & warp_bits);
if (lane_id == 0) if (lane_id == 0)
...@@ -209,35 +217,77 @@ inline __device__ T binary_cumsum_exclusive( ...@@ -209,35 +217,77 @@ inline __device__ T binary_cumsum_exclusive(
__syncthreads(); __syncthreads();
// load the carry from the preceding warp // load the carry from the preceding warp
if (warp >= 1) { if (warp_id >= 1)
warp_sum = warp_sum+smem[warp - 1]; warp_sum += smem[warp_id - 1];
}
return warp_sum; return warp_sum;
} }
// Advance `ptr` by `offset` raw bytes (NOT by `offset` elements of T).
// The address is viewed as char* for the arithmetic and then cast back to T*,
// so the caller must supply a byte offset that lands on a valid T.
//   ptr    : base pointer
//   offset : displacement in bytes (ga_ssize; presumably signed, so a negative
//            value would step backwards -- confirm against the gpuarray headers)
// Returns the displaced pointer, typed as T*.
template <typename T>
static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
    char *raw = (char*)ptr;
    raw += offset;
    return (T*)raw;
}
// get array element using raw(byte) offset
template <typename T> template <typename T>
void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) { static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
extern radix_t smem[]; return *((T*)((char*)ptr + offset));
ssize_t bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup? }
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
/*
extern __shared__ radix_t smem[];
ga_ssize __shared__ bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk = true; bool is_topk = true;
bool is_topkth = true; // exactly k-th largest bool is_topkth = true; // exactly k-th largest
radix_t out_idx;
const size_t idx = threadIdx.x;
size_t __shared__ k2, exceed;
const ga_uint warp_id = idx / 32;
const ga_uint lane_id = idx % 32;
radix_t *wmem = (radix_t*)(smem) + warp_id * 32;
const bool in_range = (idx < size);
is_topk &= in_range;
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 0. get the slice for thread block to work on
size_t gid = blockIdx.x, gidx;
$set_slice
//for(int i=0; i<NDIM; i++) {
//gidx = gid % dims_$${i};
//gid /= dims_$${i};
//dsti = ptr_add(dsti, gidx*dsti_strides_$${i+1};
//dstv = ptr_add(dstv, gidx*dstv_strides_$${i+1};
//src = ptr_add(src, gidx*src_strides_$${i+1});
//}
// 1. filter is_topk and is_topkth using radix select
size_t idx = threadIdx.x;
size_t k2 = k, exceed;
int warp_id = idx / 32;
int lane_id = idx % 32;
radix_t wmem = smem + warp_id * 32;
bool in_range = (idx < size);
RadixConfig<T>::RadixType x = in_range ? RadixConfig<T>::convert(src[idx]) : 0;
// 1. find the kth largest value using radix select
// 1.1 for each radix mask, count
smem[threadIdx.x] = 0;
#pragma unroll #pragma unroll
for (int i=bitsof(T)-RADIX_BITS; i; i-=RADIX_BITS) { for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
radix_t mask = (RADIX_SIZE-1)<<i; smem[idx] = 0;
int digit = (x>>i) & (RADIX_SIZE-1); int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp // count within warp
#pragma unroll #pragma unroll
...@@ -245,43 +295,34 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) { ...@@ -245,43 +295,34 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
bool incr_bin = (bin == digit) && is_topkth && in_range; bool incr_bin = (bin == digit) && is_topkth && in_range;
unsigned int incr_bin_warp = __ballot(incr_bin); unsigned int incr_bin_warp = __ballot(incr_bin);
if (lane_id==0) if (lane_id==0)
wmem[bin] += __popc(bin_warp); wmem[bin] += __popc(incr_bin_warp);
} }
__syncthreads(); __syncthreads();
// sum counts across all warps // sum counts across all warps
// TODO: test in-block parallel sum? // TODO: test in-block parallel sum?
if (idx<RADIX_SIZE) if (idx < RADIX_SIZE) {
bins[idx] = 0; for(int w=32; w<blockDim.x; w+=32)
if (idx==0) { smem[idx] += smem[idx + w];
for(int w=1; w<blockDim.x/32; ++w) {
#pragma unroll
for(int bin=0; bin<RADIX_SIZE; ++bin) {
smem[bin] += wmem[bin];
}
}
} }
__syncthreads(); __syncthreads();
// broadcast sum result
if (idx < RADIX_SIZE)
smem[idx] = bins[idx];
__syncthreads();
// calculate k minus cumsum(count) // calculate k minus cumsum(count)
exceed = -k; // how many the number of is_topk exceeds k if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) { if (idx == 0) {
bins[0] = k2 - smem[0]; exceed = k; // how many the number of is_topk exceeds k
if (bins[0] > 0) bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
k2 = bins[0]; if (bins[RADIX_SIZE-1] > 0)
else if (bins[0] < 0) k2 = bins[RADIX_SIZE-1];
exceed = max(exceed, bins[0]); else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll #pragma unroll
for(int bin=1; bin<RADIX_SIZE; ++bin) { for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin] = bins[bin-1] - smem[bin]; bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin] > 0) if (bins[bin-1] > 0)
k2 = bins[bin]; k2 = bins[bin-1];
else if (bins[bin] < 0) else
exceed = max(exceed, bins[bin]); exceed = min(exceed, bins[bin-1]);
} }
} }
__syncthreads(); __syncthreads();
...@@ -290,7 +331,7 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) { ...@@ -290,7 +331,7 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
// smem -> count // smem -> count
// bins -> k2 - cumsum(count) // bins -> k2 - cumsum(count)
if (is_topk && is_topkth) { if (is_topk && is_topkth) {
ssize_t icount = bins[digit]; ga_ssize icount = bins[digit];
if (icount > 0) { if (icount > 0) {
is_topkth = false; is_topkth = false;
} else if (icount < 0) { } else if (icount < 0) {
...@@ -305,17 +346,23 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) { ...@@ -305,17 +346,23 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
} }
// 2. find the index of output array, if exists // 2. find the index of output array, if exists
//
// top_kth value may not be unique, so we need to
// count how many is needed
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values // perform binary cumsum on is_topkth to drop exceeding top-kth values
radix_t topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth); out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
if (topkth_idx >= exceed) is_topk &= (out_idx < exceed);
is_topk = false; }
// perform binary cumsum on is_topk to determine idx to put result // perform binary cumsum on is_topk to determine the indices to put result
topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth); out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
if (is_topk) __syncthreads();
dst[topkth_idx] = idx;
if (is_topk) {
$write_value;
// ptr_at(dstv, out_idx * dstv_strides_0) = xval;
$write_index;
// ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
}
*/
} }
...@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc. ...@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \ from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian, consider_constant jacobian, hessian, consider_constant
from theano.tensor.sort import sort, argsort, argtopk from theano.tensor.sort import sort, argsort, topk, argtopk, topk_and_argtopk
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze, from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, fill_diagonal_offset, repeat, bartlett, fill_diagonal, fill_diagonal_offset,
cumsum, cumprod) cumsum, cumprod)
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论