提交 95f6eda6 authored 作者: Adam Becker's avatar Adam Becker

mixed changes

- add multidim support for top_k
- use unified TopKOp, which can implement topk, argtopk, or both
上级 ece8b25b
差异被折叠。
......@@ -113,6 +113,7 @@ struct RadixConfig<double> {
}
};
#ifdef USE_HALF
template <>
struct RadixConfig<half> {
typedef unsigned int RadixType;
......@@ -138,23 +139,29 @@ struct RadixConfig<half> {
#endif
}
};
#endif
// $$inp_t should be replaced in c_code
// we cannot use templated __global__ because gpuarray API does not support it yet
#define NDIM $ndim
#define INPUT_TYPE $inp_t
#define INDEX_TYPE $out_t
#define bitsof(T) (sizeof(T)*8)
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<T>::RadixType
#define radix_t RadixConfig<INPUT_TYPE>::RadixType
// Compile-time guard: reject configurations with more than 32 radix bins
// (the warp size). A RADIX_SIZE of exactly 32 is accepted by this check.
#if RADIX_SIZE > 32
#error "RADIX_SIZE must not exceed warp size (32)"
#endif
template <typename T>
inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
static inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads whose id is *no greater than* the current thread
// cumsum within warp
unsigned int warp_bits = __ballot(in);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((2<<lane_id)-1) & warp_bits);
if (lane_id == 0)
......@@ -175,20 +182,21 @@ inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bo
__syncthreads();
// load the carry from the preceding warp
if (warp >= 1) {
warp_sum = warp_sum+smem[warp - 1];
if (warp_id >= 1) {
warp_sum = warp_sum+smem[warp_id - 1];
}
return warp_sum;
}
template <typename T>
inline __device__ T binary_cumsum_exclusive(
static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread
// cumsum within warp
unsigned int warp_bits = __ballot(in);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((1<<lane_id)-1) & warp_bits);
if (lane_id == 0)
......@@ -209,35 +217,77 @@ inline __device__ T binary_cumsum_exclusive(
__syncthreads();
// load the carry from the preceding warp
if (warp >= 1) {
warp_sum = warp_sum+smem[warp - 1];
}
if (warp_id >= 1)
warp_sum += smem[warp_id - 1];
return warp_sum;
}
// Advance `ptr` by `offset` bytes (a raw byte count, not an element count)
// and hand the result back typed as T*.
template <typename T>
static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
    char *raw = (char*)ptr;
    return (T*)(raw + offset);
}
// get array element using raw(byte) offset
template <typename T>
void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
extern radix_t smem[];
ssize_t bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
// Returns a reference to the T located `offset` bytes past `ptr`.
// The pointer is cast to char* before the addition, so `offset` is a byte
// count rather than an element index; the returned reference can be read
// from or assigned to.
static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
/*
extern __shared__ radix_t smem[];
ga_ssize __shared__ bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk = true;
bool is_topkth = true; // exactly k-th largest
radix_t out_idx;
const size_t idx = threadIdx.x;
size_t __shared__ k2, exceed;
const ga_uint warp_id = idx / 32;
const ga_uint lane_id = idx % 32;
radix_t *wmem = (radix_t*)(smem) + warp_id * 32;
const bool in_range = (idx < size);
is_topk &= in_range;
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 0. get the slice for thread block to work on
size_t gid = blockIdx.x, gidx;
$set_slice
//for(int i=0; i<NDIM; i++) {
//gidx = gid % dims_$${i};
//gid /= dims_$${i};
//dsti = ptr_add(dsti, gidx*dsti_strides_$${i+1};
//dstv = ptr_add(dstv, gidx*dstv_strides_$${i+1};
//src = ptr_add(src, gidx*src_strides_$${i+1});
//}
// 1. filter is_topk and is_topkth using radix select
size_t idx = threadIdx.x;
size_t k2 = k, exceed;
int warp_id = idx / 32;
int lane_id = idx % 32;
radix_t wmem = smem + warp_id * 32;
bool in_range = (idx < size);
RadixConfig<T>::RadixType x = in_range ? RadixConfig<T>::convert(src[idx]) : 0;
// 1. find the kth largest value using radix select
// 1.1 for each radix mask, count
smem[threadIdx.x] = 0;
#pragma unroll
for (int i=bitsof(T)-RADIX_BITS; i; i-=RADIX_BITS) {
radix_t mask = (RADIX_SIZE-1)<<i;
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
smem[idx] = 0;
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
......@@ -245,43 +295,34 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
unsigned int incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
wmem[bin] += __popc(bin_warp);
wmem[bin] += __popc(incr_bin_warp);
}
__syncthreads();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx==0) {
for(int w=1; w<blockDim.x/32; ++w) {
#pragma unroll
for(int bin=0; bin<RADIX_SIZE; ++bin) {
smem[bin] += wmem[bin];
}
}
if (idx < RADIX_SIZE) {
for(int w=32; w<blockDim.x; w+=32)
smem[idx] += smem[idx + w];
}
__syncthreads();
// broadcast sum result
if (idx < RADIX_SIZE)
smem[idx] = bins[idx];
__syncthreads();
// calculate k minus cumsum(count)
exceed = -k; // how many the number of is_topk exceeds k
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
bins[0] = k2 - smem[0];
if (bins[0] > 0)
k2 = bins[0];
else if (bins[0] < 0)
exceed = max(exceed, bins[0]);
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=1; bin<RADIX_SIZE; ++bin) {
bins[bin] = bins[bin-1] - smem[bin];
if (bins[bin] > 0)
k2 = bins[bin];
else if (bins[bin] < 0)
exceed = max(exceed, bins[bin]);
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
__syncthreads();
......@@ -290,7 +331,7 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ssize_t icount = bins[digit];
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
......@@ -305,17 +346,23 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
}
// 2. find the index of output array, if exists
//
// top_kth value may not be unique, so we need to
// count how many is needed
// perform binary cumsum on is_topkth to drop exceeding top-kth values
radix_t topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
if (topkth_idx >= exceed)
is_topk = false;
// perform binary cumsum on is_topk to determine idx to put result
topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
if (is_topk)
dst[topkth_idx] = idx;
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
__syncthreads();
if (is_topk) {
$write_value;
// ptr_at(dstv, out_idx * dstv_strides_0) = xval;
$write_index;
// ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
}
*/
}
......@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian, consider_constant
from theano.tensor.sort import sort, argsort, argtopk
from theano.tensor.sort import sort, argsort, topk, argtopk, topk_and_argtopk
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, fill_diagonal_offset,
cumsum, cumprod)
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论