Commit 933cb859, authored by Adam Becker

Implement new kernel to handle arbitrary shape

Parent commit: bf83d342
......@@ -9,26 +9,52 @@
// will all be adjacent
template <typename T>
struct RadixConfig {};
struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <>
struct RadixConfig<float> {
struct RadixConfig<ga_float> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) {
static inline __device__ RadixType convert(ga_float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
static inline __device__ float deconvert(RadixType v) {
static inline __device__ ga_float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
// Order-preserving bijection between ga_double and unsigned 64-bit ints:
// positives get only the sign bit flipped, negatives get all bits flipped,
// so that unsigned integer comparison of the converted values matches the
// ordering of the original doubles. Required by the radix-select kernels.
template <>
struct RadixConfig<ga_double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(ga_double v) {
RadixType x = __double_as_longlong(v);
// mask = all ones when the sign bit is set (negative), else just the sign bit
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ ga_double deconvert(RadixType v) {
// inverse of convert(): the (now flipped) top bit tells which mask was used
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
};
template <>
struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType;
......@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
};
template <>
struct RadixConfig<char> {
struct RadixConfig<ga_byte> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(char v) {
static inline __device__ RadixType convert(ga_byte v) {
return 128u + v;
}
static inline __device__ char deconvert(RadixType v) {
static inline __device__ ga_byte deconvert(RadixType v) {
return v - 128;
}
};
......@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2);
return 32768u + v;
return 32768u ^ v;
}
static inline __device__ ga_short deconvert(RadixType v) {
......@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
return 2147483648u + v;
return (1u << 31) ^ v;
}
static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u;
return (1u << 31) ^ v;
}
};
template <>
struct RadixConfig<long> {
struct RadixConfig<ga_long> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) {
static inline __device__ RadixType convert(ga_long v) {
assert(sizeof(long) == 8);
return 9223372036854775808ull + v;
return (1ull << 63) ^ v;
}
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
static inline __device__ ga_long deconvert(RadixType v) {
return (1ull << 63) ^ v;
}
};
#ifdef USE_HALF
// TODO: make this work
template <>
struct RadixConfig<half> {
typedef ga_uint RadixType;
......@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
// Load one element located `offset` raw bytes past `ptr`, going through the
// GPU's read-only data cache (__ldg).
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
  const T *elem = (const T *)((const char *)ptr + offset);
  return __ldg(elem);
}
// works when length on axis is within max allowed threads in block (1024)
// One thread block handles one slice along the top-k axis; each thread owns
// exactly one element. A radix select over RADIX_BITS-wide digits marks the
// top-k elements, then a binary cumsum compacts them into the output.
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
// per-warp digit histograms; also reused as scratch by binary_cumsum_exclusive
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
// bins[b] = (remaining k) minus count of elements whose current digit >= b;
// the extra slot bins[RADIX_SIZE] is a positive sentinel so the
// bins[digit+1] read below is always in bounds
ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
// is_topk: this thread's element is (still) in the top-k
// is_topkth: element still ties with the top-k-th value on all digits seen so far
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_ushort idx = LID_0;
// k2: top-k slots still unaccounted for; exceed: how many tied
// top-k-th values overflow the k requested slots
ga_size LOCAL_MEM k2, exceed;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
// (negative k means "smallest k": flipping all bits reverses the radix order)
if (k<0) { x = ~x; k = -k; }
if (idx==0) {
k2 = k;
bins[RADIX_SIZE] = 1;
}
// 1. filter is_topk and is_topkth using radix select
// (digits scanned from most to least significant)
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// bins = k - cumsum(smem[:RADIX_SIZE])
// (cumsum runs from the highest digit downward, serially on thread 0)
if (idx == 0) {
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
#pragma unroll
for (int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
// slots remain even after counting this whole digit bin:
// definitely top-k, and no longer tied with the k-th value
is_topkth = false;
} else if (bins[digit+1] <= 0) {
// the k-th value has a strictly larger digit here: drop out
is_topk = false;
is_topkth = false;
}
}
}
// after all digits, the first non-positive bin (from the top) gives how
// many values tie with the top-k-th beyond the k requested slots
// NOTE(review): `exceed` is only assigned when some bin is non-positive --
// presumably guaranteed whenever k <= size; verify against callers
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (bins[bin] <= 0) {
exceed = -bins[bin];
break;
}
}
}
local_barrier();
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= ((!is_topkth) || out_idx>=exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
// works when length on axis is larger than max allowed threads in block (1024)
// Each thread owns a private chunk of `inp_per_thread` consecutive elements.
// Strategy: (1) radix-select the bit pattern of the top-k-th value,
// (2) emit every value ranked strictly ahead of it, then (3) emit just
// enough tied values to reach exactly k outputs.
KERNEL void k_topk_dense_large(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) {
// per-warp digit histograms, reused as scratch for binary_cumsum
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
// digits of the top-k-th value resolved so far, plus a mask of the bit
// positions already resolved
LOCAL_MEM radix_t known_bits, known_bits_mask;
radix_t out_idx;
// number of outputs already written by earlier rounds (shared offset)
ga_size LOCAL_MEM write_base;
INPUT_TYPE xval;
radix_t x;
ga_int i;
bool in_range, is_topk;
const ga_size idx = LID_0;
// k2: rank still to be located while narrowing down the top-k-th value
ga_size LOCAL_MEM k2;
const ga_ushort warp_id = idx / GA_WARP_SIZE;
const ga_ushort lane_id = idx % GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// advance src to this thread's private chunk
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
// all-ones when k < 0 ("smallest k"): XOR-ing reverses the radix order
LOCAL_MEM radix_t inv_bits;
if (idx==0) {
known_bits = known_bits_mask = 0;
k2 = abs(k);
inv_bits = (k>=0) ? 0 : (~0);
write_base = 0;
}
if (k<0) { k = -k; }
local_barrier();
// 1. find bits of top-k-th value using radix select
// (one pass per RADIX_BITS-wide digit, most significant first;
// src is re-read from global memory on every pass)
#pragma unroll
for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
/*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
if (lane_id == 0) {
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
smem[bin + warp_id*RADIX_SIZE] = 0;
}
}
local_barrier();
// histogram the current digit over elements still matching the
// already-resolved prefix of the top-k-th value
for (int j=0; j<inp_per_thread; ++j) {
in_range = (idx*inp_per_thread+j) < size;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (
(bin == digit) &&
((x&known_bits_mask) == known_bits) &&
in_range);
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
}
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE;
w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// update known bits
// walk bins from the largest digit down: the k2-th remaining element
// lies in the first bin whose count reaches k2
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] >= k2) {
known_bits |= (bin << i);
known_bits_mask |= ((RADIX_SIZE-1) << i);
break;
} else
k2 -= smem[bin];
}
}
local_barrier();
}
/*
if (idx < RADIX_SIZE) {
ptr_at(dstv, idx*dstv_strides_0) = known_bits;
ptr_at(dstv, idx*dstv_strides_0) = smem[idx];
}
return;
*/
// 2. write values ranked strictly ahead of the top-k-th
// (converted radix value strictly greater than known_bits)
for (i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range;
// inclusive cumsum: each writer gets a 1-based slot within this round
out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
}
local_barrier();
// last thread holds the round's total; bump the shared write offset
// NOTE(review): uses blockDim.x rather than LDIM_0 -- presumably always
// equal in this backend, but confirm for non-CUDA targets
if (idx == blockDim.x - 1)
write_base += out_idx;
local_barrier();
}
// 3. write values equal to top-kth
// (ties are admitted only until exactly k outputs exist)
for (i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x == known_bits) && in_range;
out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
is_topk = ((out_idx+write_base) <= abs(k)) && is_topk;
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
}
local_barrier();
if (idx == blockDim.x - 1)
write_base += out_idx;
local_barrier();
// stop early once all k outputs have been produced
if(write_base >= abs(k))
break;
}
}
This diff is collapsed.
......@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code
__props__ = ('axis', 'return_values', 'return_indices')
__props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
def __init__(self, axis=-1, return_indices=False, return_values=True):
def __init__(self, axis=-1, return_indices=False, return_values=True, idx_dtype='int64'):
assert isinstance(axis, int)
assert return_indices or return_values
self.axis = axis
self.return_indices = return_indices
self.return_values = return_values
self.idx_dtype = idx_dtype
def __str__(self):
return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k, idx_dtype='int64'):
def make_node(self, inp, k):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp)
......@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs.append(inp.type())
if self.return_indices:
outs.append(
theano.tensor.TensorType(dtype=idx_dtype, broadcastable=bcast)())
theano.tensor.TensorType(dtype=self.idx_dtype, broadcastable=bcast)())
return theano.Apply(self, [inp, k], outs)
def perform(self, node, inputs, output_storage):
......@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return TopKOp(axis=axis, return_indices=True, return_values=False)(x, k, idx_dtype=idx_dtype)
return TopKOp(axis=axis, return_indices=True, return_values=False, idx_dtype=idx_dtype)(x, k)
def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
'''
"""
Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details.
'''
"""
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return TopKOp(axis=axis, return_indices=True)(x, k, idx_dtype=idx_dtype)
return TopKOp(axis=axis, return_indices=True, idx_dtype=idx_dtype)(x, k)
......@@ -21,14 +21,12 @@ _int_dtypes = (
'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64')
def gen_unique_vector(size, dtype):
# generate a randomized vector with unique elements
retval = np.arange(size*3) + np.random.uniform(-1., 1.)
return (retval[np.random.permutation(size)] - size*1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase):
def setUp(self):
......@@ -236,7 +234,6 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_TopK(unittest.TestCase):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment