提交 933cb859 authored 作者: Adam Becker's avatar Adam Becker

Implement new kernel to handle arbitrary shape

上级 bf83d342
......@@ -9,26 +9,52 @@
// will all be adjacent
template <typename T>
struct RadixConfig {};
struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <>
struct RadixConfig<float> {
struct RadixConfig<ga_float> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) {
static inline __device__ RadixType convert(ga_float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
static inline __device__ float deconvert(RadixType v) {
static inline __device__ ga_float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
// Radix mapping for doubles: reinterpret the IEEE-754 bit pattern and flip
// bits so that *unsigned* integer ordering of the result matches the numeric
// ordering of the original values (sign bit flipped for non-negatives, all
// bits flipped for negatives).  deconvert() is the exact inverse.
template <>
struct RadixConfig<ga_double> {
    typedef unsigned long long int RadixType;
    static inline __device__ RadixType convert(ga_double v) {
        RadixType x = __double_as_longlong(v);
        // x >> 63 is 1 for negatives, so -(x >> 63) is an all-ones mask
        // (flip every bit); for non-negatives the mask is just the sign bit
        RadixType mask = -((x >> 63)) | 0x8000000000000000;
        return (x ^ mask);
    }
    static inline __device__ ga_double deconvert(RadixType v) {
        // top bit set  -> value was non-negative: (v>>63)-1 == 0, flip sign bit
        // top bit clear-> value was negative: (v>>63)-1 == ~0, flip all bits
        RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
        return __longlong_as_double(v ^ mask);
    }
};
template <>
struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType;
......@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
};
template <>
struct RadixConfig<char> {
struct RadixConfig<ga_byte> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(char v) {
static inline __device__ RadixType convert(ga_byte v) {
return 128u + v;
}
static inline __device__ char deconvert(RadixType v) {
static inline __device__ ga_byte deconvert(RadixType v) {
return v - 128;
}
};
......@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2);
return 32768u + v;
return 32768u ^ v;
}
static inline __device__ ga_short deconvert(RadixType v) {
......@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
return 2147483648u + v;
return (1u << 31) ^ v;
}
static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u;
return (1u << 31) ^ v;
}
};
template <>
struct RadixConfig<long> {
struct RadixConfig<ga_long> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) {
static inline __device__ RadixType convert(ga_long v) {
assert(sizeof(long) == 8);
return 9223372036854775808ull + v;
return (1ull << 63) ^ v;
}
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
static inline __device__ ga_long deconvert(RadixType v) {
return (1ull << 63) ^ v;
}
};
#ifdef USE_HALF
// TODO: make this work
template <>
struct RadixConfig<half> {
typedef ga_uint RadixType;
......@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
// Read one array element addressed by a raw byte offset, going through the
// GPU's read-only data cache (__ldg).
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
    const char *base = (const char *)ptr;
    return __ldg((const T *)(base + offset));
}
// works when length on axis is within max allowed threads in block (1024)
//
// One thread block handles one slice along the top-k axis; each thread owns
// at most one element (idx = LID_0).  A most-significant-digit radix pass
// over the radix-converted values locates the top-k-th value, after which
// surviving elements are compacted into dstv/dsti via an exclusive binary
// cumsum.  The "$"-prefixed tokens are substituted host-side (see
// GpuTopKOp.gpu_kernels) before compilation.
// NOTE(review): smem is sized for at most 32 warps per block -- confirm the
// launch configuration in c_code never exceeds 32 * GA_WARP_SIZE threads.
KERNEL void k_topk_dense(
        $dims
        // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
        ga_ssize k,
        INPUT_TYPE* src,
        $src_strides
        // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
        ga_size size) {
    // per-warp digit histograms: warp w owns smem[w*RADIX_SIZE .. +RADIX_SIZE-1]
    LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
    // bins[b] = k2 - (count of elements with digit >= b); the extra slot
    // bins[RADIX_SIZE] is a sentinel (set to 1 below) so that reading
    // bins[digit+1] in the elimination step is always in bounds
    ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
    // is_topk: element is (still) among the top k
    // is_topkth: element is still a candidate for being the exact k-th value
    bool is_topk=true, is_topkth=true;
    radix_t out_idx;
    const ga_ushort idx = LID_0;
    // k2: top-k slots not yet accounted for; exceed: number of duplicate
    // top-k-th values beyond k.  Both shared, written by thread 0 only.
    ga_size LOCAL_MEM k2, exceed;
    const ga_ubyte warp_id = idx / GA_WARP_SIZE;
    const ga_ubyte lane_id = idx % GA_WARP_SIZE;
    const bool in_range = (idx < size);
    is_topk &= in_range;
    // 0. get the slice for thread block to work on
    // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
    ga_size gid = GID_0, gidx;
    $set_slice
    //for(int i=1; i<NDIM; i++) {
    //  gidx = gid % dims_$${i};
    //  gid /= dims_$${i};
    //  dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
    //  dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
    //  src = ptr_add(src, gidx*src_strides_$${i});
    //}
    // get input and its radix friendly form
    const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
    radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
    // resolve negative k: inverting every bit of x reverses the radix order,
    // turning a smallest-k query into a largest-k one
    if (k<0) { x = ~x; k = -k; }
    if (idx==0) {
        k2 = k;
        bins[RADIX_SIZE] = 1; // sentinel, see declaration of bins above
    }
    // 1. filter is_topk and is_topkth using radix select,
    //    scanning digits from the most significant downwards
    #pragma unroll
    for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        int digit = (x>>i) & (RADIX_SIZE-1);
        // count within warp
        // NOTE(review): mask-less __ballot is the legacy (pre-Volta) warp
        // vote intrinsic -- confirm the build targets an architecture where
        // implicit warp synchrony still holds.
        #pragma unroll
        for (int bin=0; bin<RADIX_SIZE; ++bin) {
            bool incr_bin = (bin == digit) && is_topkth && in_range;
            ga_uint incr_bin_warp = __ballot(incr_bin);
            if (lane_id==0)
                smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
        }
        local_barrier();
        // sum counts across all warps
        // TODO: test in-block parallel sum?
        if (idx < RADIX_SIZE) {
            for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
                smem[idx] += smem[idx + w];
        }
        local_barrier();
        // bins = k - cumsum(smem[:RADIX_SIZE]) accumulated from the top
        // (largest) digit downwards; k2 shrinks to the count still missing
        if (idx == 0) {
            bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
            if (bins[RADIX_SIZE-1] > 0)
                k2 = bins[RADIX_SIZE-1];
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin; --bin) {
                bins[bin-1] = bins[bin] - smem[bin-1];
                if (bins[bin-1] > 0)
                    k2 = bins[bin-1];
            }
        }
        local_barrier();
        // smem -> count
        // bins -> k2 - cumsum(count)
        if (is_topk && is_topkth) {
            ga_ssize icount = bins[digit];
            if (icount > 0) {
                // whole bucket fits inside the top k: element is certainly
                // top-k, but can no longer be the exact k-th value
                is_topkth = false;
            } else if (bins[digit+1] <= 0) {
                // buckets above this digit already cover k: eliminate
                is_topk = false;
                is_topkth = false;
            }
        }
    }
    // determine how many top-k-th duplicates exceed k
    if (idx==0) {
        #pragma unroll
        for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
            if (bins[bin] <= 0) {
                exceed = -bins[bin];
                break;
            }
        }
        // NOTE(review): if no bin satisfies bins[bin] <= 0 (possible when k
        // exceeds the number of valid elements) `exceed` is read below
        // without ever being written -- confirm callers guarantee k <= size.
    }
    local_barrier();
    // 2. find the index of output array, if exists
    if (exceed != 0) {
        // top_kth value may not be unique, so we need to
        // perform binary cumsum on is_topkth to drop exceeding top-kth values
        out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
        is_topk &= ((!is_topkth) || out_idx>=exceed);
    }
    // perform binary cumsum on is_topk to determine the indices to put result
    out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
    local_barrier();
    if (is_topk) {
#if WRITE_VALUE == 1
        ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
        ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
    }
}
// works when length on axis is larger than max allowed threads in block (1024)
//
// Each thread scans `inp_per_thread` consecutive elements of its slice.
// Pass 1 reconstructs the radix bit pattern of the top-k-th value digit by
// digit (known_bits / known_bits_mask).  Pass 2 writes every element that is
// radix-greater than that pattern; pass 3 then writes just enough elements
// equal to the pattern to reach exactly k outputs.
// "$"-prefixed tokens are substituted host-side (GpuTopKOp.gpu_kernels).
KERNEL void k_topk_dense_large(
        $dims
        // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
        ga_ssize k,
        INPUT_TYPE* src,
        $src_strides
        // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
        ga_size size, ga_ushort inp_per_thread) {
    // per-warp digit histograms, accumulated over each thread's chunk
    LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
    // bit pattern of the top-k-th value found so far, and the mask of digit
    // positions already fixed
    LOCAL_MEM radix_t known_bits, known_bits_mask;
    radix_t out_idx;
    // number of outputs already written by previous compaction rounds
    ga_size LOCAL_MEM write_base;
    INPUT_TYPE xval;
    radix_t x;
    ga_int i;
    bool in_range, is_topk;
    const ga_size idx = LID_0;
    // top-k slots not yet accounted for while walking digit bins
    ga_size LOCAL_MEM k2;
    const ga_ushort warp_id = idx / GA_WARP_SIZE;
    const ga_ushort lane_id = idx % GA_WARP_SIZE;
    // 0. get the slice for thread block to work on
    // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
    ga_size gid = GID_0, gidx;
    $set_slice
    //for(int i=1; i<NDIM; i++) {
    //  gidx = gid % dims_$${i};
    //  gid /= dims_$${i};
    //  dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
    //  dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
    //  src = ptr_add(src, gidx*src_strides_$${i});
    //}
    // advance src to this thread's contiguous chunk of the slice
    src = ptr_add(src, idx*inp_per_thread*src_strides_0);
    // XOR mask applied to every radix value: ~0 reverses the order so that
    // a negative k (smallest-k query) reuses the largest-k machinery
    LOCAL_MEM radix_t inv_bits;
    if (idx==0) {
        known_bits = known_bits_mask = 0;
        k2 = abs(k);
        inv_bits = (k>=0) ? 0 : (~0);
        write_base = 0;
    }
    if (k<0) { k = -k; }
    local_barrier();
    // 1. find bits of top-k-th value using radix select,
    //    one RADIX_BITS-wide digit per iteration, most significant first
    #pragma unroll
    for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        /*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
        if (lane_id == 0) {
            #pragma unroll
            for (int bin=0; bin<RADIX_SIZE; ++bin) {
                smem[bin + warp_id*RADIX_SIZE] = 0;
            }
        }
        local_barrier();
        // histogram this thread's chunk, counting only elements whose
        // already-fixed digits match known_bits
        for (int j=0; j<inp_per_thread; ++j) {
            in_range = (idx*inp_per_thread+j) < size;
            xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
            x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
            ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
            // count within warp
            // NOTE(review): mask-less __ballot is the legacy (pre-Volta)
            // intrinsic -- confirm target architecture.
            #pragma unroll
            for (int bin=0; bin<RADIX_SIZE; ++bin) {
                bool incr_bin = (
                    (bin == digit) &&
                    ((x&known_bits_mask) == known_bits) &&
                    in_range);
                ga_uint incr_bin_warp = __ballot(incr_bin);
                if (lane_id==0)
                    smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
            }
        }
        local_barrier();
        // sum counts across all warps
        // TODO: test in-block parallel sum?
        if (idx < RADIX_SIZE) {
            for(int w=RADIX_SIZE;
                    w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
                    w+=RADIX_SIZE)
                smem[idx] += smem[idx + w];
        }
        local_barrier();
        // update known bits: walk bins from the largest digit down until the
        // bin containing the k2-th element is found; fix that digit
        // NOTE(review): `bin` and the shifted constant are 32-bit ints, so
        // for 64-bit INPUT_TYPEs with i >= 32 the shift `bin << i` would
        // need a 64-bit operand -- TODO confirm for ga_double / ga_long.
        if (idx==0) {
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
                if (smem[bin] >= k2) {
                    known_bits |= (bin << i);
                    known_bits_mask |= ((RADIX_SIZE-1) << i);
                    break;
                } else
                    k2 -= smem[bin];
            }
        }
        local_barrier();
    }
    /*
    if (idx < RADIX_SIZE) {
    ptr_at(dstv, idx*dstv_strides_0) = known_bits;
    ptr_at(dstv, idx*dstv_strides_0) = smem[idx];
    }
    return;
    */
    // 2. write values ranked strictly above the top-k-th value
    //    (radix-greater than the fully-known threshold pattern)
    for (i=0; i<inp_per_thread; ++i) {
        in_range = (idx*inp_per_thread+i) < size;
        xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
        x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
        is_topk = (x > known_bits) && in_range;
        // inclusive cumsum gives each writer a 1-based rank within this round
        out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
        if (is_topk) {
#if WRITE_VALUE == 1
            ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
            ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
        }
        local_barrier();
        // last thread's inclusive cumsum is the total written this round
        // NOTE(review): uses CUDA-specific blockDim.x where the rest of the
        // file uses LDIM_0 -- confirm intended for OpenCL builds.
        if (idx == blockDim.x - 1)
            write_base += out_idx;
        local_barrier();
    }
    // 3. write values equal to top-kth, stopping once k outputs exist
    for (i=0; i<inp_per_thread; ++i) {
        in_range = (idx*inp_per_thread+i) < size;
        xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
        x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
        is_topk = (x == known_bits) && in_range;
        out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
        // keep only as many threshold-equal elements as still fit within k
        is_topk = ((out_idx+write_base) <= abs(k)) && is_topk;
        if (is_topk) {
#if WRITE_VALUE == 1
            ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
            ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
        }
        local_barrier();
        if (idx == blockDim.x - 1)
            write_base += out_idx;
        local_barrier();
        if(write_base >= abs(k))
            break;
    }
}
from __future__ import absolute_import, print_function, division
import os
from string import Template
import pdb
import numpy as np
import theano
from theano import Apply
from theano.tensor import as_tensor_variable
......@@ -20,7 +22,6 @@ except ImportError as e:
pass
# TODO add support when slice size is larger than max allowed block size (1024)
# TODO add runtime opt, if k==1, use max/min reduce
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed
......@@ -33,12 +34,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
'''
__props__ = TopKOp.__props__
def __init__(self, axis=-1, return_indices=False, return_values=True):
def __init__(self, axis=-1, return_values=True, return_indices=False, idx_dtype='int64'):
GpuKernelBase.__init__(self)
TopKOp.__init__(
self, axis=axis,
return_values=return_values,
return_indices=return_indices)
return_indices=return_indices,
idx_dtype=idx_dtype)
def c_headers(self):
    """Return the C headers required by this Op's generated C code."""
    return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
......@@ -54,52 +56,56 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def gpu_kernels(self, node, nodename):
# load kernel source
device_type = node.inputs[0].type.context.kind
knames = ['k_topk_dense', 'k_topk_dense_large']
kernel_ext = {b'cuda':'.cu', b'opencl':'.cl'}[device_type]
try:
kernel_filename = 'topk_kernel%s' % kernel_ext
common_ext = {b'cuda':'.cuh', b'opencl':'.h'}[device_type]
kernel_src = {}
for kname in knames:
with open(os.path.join(
os.path.dirname(__file__), kernel_filename
os.path.dirname(__file__), kname + kernel_ext
), 'r') as f:
kernel_src = f.read()
except FileNotFoundError:
raise RuntimeError(
'Cannot find GPU kernel '
'implementation for device "%s"' % device_type)
kernel_src[kname] = f.read()
# prepare "$" macros
ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ga_ssize src_strides_%d, ' % i for i in range(ndim))
set_slice_code = '''
gidx = gid %% dims_%(i)d;
gid /= dims_%(i)d;
{dstv};
{dsti};
src = ptr_add(src, gidx*src_strides_%(i)d);\n'''.format(
dstv='dstv = ptr_add(dstv, gidx*dstv_strides_%(i)d)' if self.return_values else '',
dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '')
set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim))
flags = Kernel.get_flags(node.inputs[0].dtype)
subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(node.outputs[0].dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
dstv_strides=dstv_strides_code if self.return_values else '',
dsti_strides=dsti_strides_code if self.return_indices else '',
src_strides=src_strides_code,
set_slice=set_slice_code,
write_value=int(self.return_values),
write_index=int(self.return_indices),
ndim=str(ndim))
with open(os.path.join(
os.path.dirname(__file__), 'k_topk_common' + common_ext
), 'r') as f:
common_src = f.read()
# substitute "$" macros in kernel code
kernel_src = Template(kernel_src).substitute(**subs)
# prepare "$" macros
if device_type == b'cuda':
ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ga_ssize src_strides_%d, ' % i for i in range(ndim))
set_slice_code = '''
gidx = gid %% dims_%(i)d;
gid /= dims_%(i)d;
{dstv};
{dsti};
src = ptr_add(src, gidx*src_strides_%(i)d);\n'''.format(
dstv='dstv = ptr_add(dstv, gidx*dstv_strides_%(i)d)' if self.return_values else '',
dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '')
set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim))
flags = Kernel.get_flags(node.inputs[0].dtype)
subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(self.idx_dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
dstv_strides=dstv_strides_code if self.return_values else '',
dsti_strides=dsti_strides_code if self.return_indices else '',
src_strides=src_strides_code,
set_slice=set_slice_code,
write_value=int(self.return_values),
write_index=int(self.return_indices),
ndim=str(ndim))
elif device_type == b'opencl':
raise NotImplementedError()
# compile kernel
# compile kernels
kernels = []
param_types = [ga.SIZE] * (ndim - 1) # dims
for _ in range(int(self.return_values) + int(self.return_indices)):
param_types.append(ga.GpuArray) # dst*
......@@ -108,31 +114,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
param_types.append(ga.GpuArray) # src
param_types.extend([ga.SSIZE] * ndim) # src_strides
param_types.append(ga.SIZE) # size
self.nargs = len(param_types)
return [Kernel(
code=kernel_src,
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense']).substitute(**subs),
name='k_topk_dense',
params=param_types,
flags=flags,
objvar='k_topk_dense_' + nodename
)]
))
param_types.append(np.uint16) # inp_per_thread
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense_large']).substitute(**subs),
name='k_topk_dense_large',
params=param_types,
flags=flags,
objvar='k_topk_dense_large_' + nodename
))
return kernels
def c_code(self, node, nodename, inps, outs, sub):
if node.inputs[0].type.context.kind != b'cuda':
raise NotImplementedError('We only have CUDA implementation so far.')
raise NotImplementedError(
'%s: We only have CUDA '
'implementation so far.' % self.__class__.__name__)
x, k = inps
inp_dtc = pygpu.dtypes.dtype_to_ctype(node.inputs[0].dtype).upper()
inp_dtc = ga.dtype_to_typecode(node.inputs[0].dtype)
if not self.return_indices:
yv, = outs
out_dtype_s = ''
out_dtc = ''
else:
if self.return_values:
yv, yi = outs
else:
yi, = outs
out_dtype_s = node.outputs[0].dtype
out_dtc = pygpu.dtypes.dtype_to_ctype(out_dtype_s).upper()
out_dtype_s = self.idx_dtype
out_dtc = ga.dtype_to_typecode(out_dtype_s)
fail = sub['fail']
ctx = sub['params']
k_dtype = node.inputs[1].type.dtype_specs()[1]
......@@ -140,7 +154,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
WARP_SIZE = 32
ndim = node.inputs[0].ndim
nargs = self.nargs
reordered_axes = list(range(ndim))
axis = self.axis % ndim
del(reordered_axes[axis])
......@@ -175,16 +188,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
sstrides = ', '.join('(void*)(sstrides+%d)' % i for i in reordered_axes)
code = '''
{
const ssize_t k_ = ((%(k_dtype)s*)(PyArray_DATA(%(k)s)))[0];
const size_t *dims = PyGpuArray_DIMS(%(x)s);
size_t odims[%(ndim)d];
for (int i=0; i<%(ndim)d; i++)
odims[i] = dims[i];
odims[%(axis)d] = *((%(k_dtype)s*)(PyArray_DATA(%(k)s)));
if (odims[0] > %(MAX_TPB)d) {
odims[%(axis)d] = k_>=0 ? k_ : -k_;
if (0 == odims[%(axis)d]) {
PyErr_SetString(
PyExc_ValueError,
"topk: slice size larger than %(MAX_TPB)d is not supported");
%(fail)s; }
"topk: k must not be zero");
%(fail)s;
}
%(prep_output)s
// TODO better scheduling?
......@@ -192,32 +210,45 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size
for(int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d)
grd[0] *= dims[i];
else
blk[0] = dims[i];
}
// round up to multiples of warp size
blk[0] = ((blk[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
%(def_dvstrides)s;
%(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
// inputs per thread
unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d/2)-1) / (%(MAX_TPB)d/2);
void* args[] = {
%(dims)s
%(params_dv)s
%(params_di)s
(void*)(odims+%(axis)d),
(void*)(&k_),
(void*)(%(x)s->ga.data),
%(sstrides)s,
(void*)(dims+%(axis)d)
(void*)(dims+%(axis)d),
(void*)(&ipt)
};
int err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3,
grd, blk, 0,
args);
int err;
if (blk[0] > %(MAX_TPB)d) {
// CUDA_OUT_OF_RESOURCE if a max sized block is used
blk[0] = %(MAX_TPB)d / 2;
err = GpuKernel_call(
&k_topk_dense_large_%(nodename)s, 3,
grd, blk, 0,
args);
} else {
err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3,
grd, blk, 0,
args);
}
if (err != GA_NO_ERROR) {
PyErr_SetString(
PyExc_RuntimeError,
......@@ -228,37 +259,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
'''
return code % locals()
def make_node(self, inp, k, idx_dtype='int64'):
def make_node(self, inp, k):
ctx_name = infer_context_name(inp)
inp = as_gpuarray_variable(inp, ctx_name)
k = as_tensor_variable(k)
bcast = inp.type.broadcastable
outs = []
if self.return_values:
outs.append(inp.type())
if self.return_indices:
outs.append(GpuArrayType(
dtype=idx_dtype,
dtype=self.idx_dtype,
broadcastable=bcast,
context_name=ctx_name)())
if self.return_values:
outs.append(inp.type())
return Apply(self, [inp, k], outs)
def get_params(self, node):
    """Return the GPU context of the node's input variable.

    This is what c_code receives through ``sub['params']``.
    """
    return node.inputs[0].type.context
# def get_op_params(self):
# return [('AXIS', self.axis)]
@register_opt('fast_compile')
@op_lifter([TopKOp])
@register_opt2([TopKOp], 'fast_compile')
def local_gpua_topkop(op, ctx_name, inputs, outputs):
if isinstance(op, GpuTopKOp):
return False
axis = op.axis
rv = op.return_values
ri = op.return_indices
x, k = inputs
x = as_gpuarray_variable(x, ctx_name)
y = outputs[-1]
return GpuTopKOp(
axis=axis, return_values=rv, return_indices=ri)(x, k, idx_dtype=y.dtype)
rets = GpuTopKOp(
axis=axis, return_values=rv, return_indices=ri, idx_dtype=op.idx_dtype)(x, k)
return rets
......@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code
__props__ = ('axis', 'return_values', 'return_indices')
__props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
def __init__(self, axis=-1, return_indices=False, return_values=True):
def __init__(self, axis=-1, return_indices=False, return_values=True, idx_dtype='int64'):
assert isinstance(axis, int)
assert return_indices or return_values
self.axis = axis
self.return_indices = return_indices
self.return_values = return_values
self.idx_dtype = idx_dtype
def __str__(self):
    # NOTE(review): only `axis` is rendered; the other __props__
    # (return_values, return_indices, idx_dtype) are omitted from the
    # printed name -- confirm this is intentional for graph printing.
    return '%(op)s{axis=%(axis)d}' % dict(
        op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k, idx_dtype='int64'):
def make_node(self, inp, k):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp)
......@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs.append(inp.type())
if self.return_indices:
outs.append(
theano.tensor.TensorType(dtype=idx_dtype, broadcastable=bcast)())
theano.tensor.TensorType(dtype=self.idx_dtype, broadcastable=bcast)())
return theano.Apply(self, [inp, k], outs)
def perform(self, node, inputs, output_storage):
......@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return TopKOp(axis=axis, return_indices=True, return_values=False)(x, k, idx_dtype=idx_dtype)
return TopKOp(axis=axis, return_indices=True, return_values=False, idx_dtype=idx_dtype)(x, k)
def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
'''
"""
Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details.
'''
"""
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return TopKOp(axis=axis, return_indices=True)(x, k, idx_dtype=idx_dtype)
return TopKOp(axis=axis, return_indices=True, idx_dtype=idx_dtype)(x, k)
......@@ -21,14 +21,12 @@ _int_dtypes = (
'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64')
def gen_unique_vector(size, dtype):
    # Produce a randomized 1-D array of `size` pairwise-distinct values.
    # A strided integer ramp guarantees distinctness; the shared random
    # offset and the permutation randomize the values and their order.
    ramp = np.arange(size * 3) + np.random.uniform(-1., 1.)
    shuffled = ramp[np.random.permutation(size)]
    return (shuffled - size * 1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase):
def setUp(self):
......@@ -236,7 +234,6 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_TopK(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论