提交 933cb859 authored 作者: Adam Becker's avatar Adam Becker

Implement new kernel to handle arbitrary shape

上级 bf83d342
...@@ -9,26 +9,52 @@ ...@@ -9,26 +9,52 @@
// will all be adjacent // will all be adjacent
template <typename T> template <typename T>
struct RadixConfig {}; struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <> template <>
struct RadixConfig<float> { struct RadixConfig<ga_float> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) { static inline __device__ RadixType convert(ga_float v) {
RadixType x = __float_as_int(v); RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask); return (x ^ mask);
} }
static inline __device__ float deconvert(RadixType v) { static inline __device__ ga_float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask); return __int_as_float(v ^ mask);
} }
}; };
// Bijective, order-preserving mapping between ga_double and 64-bit unsigned
// integers: after convert(), unsigned comparison of the radix words agrees
// with the ordering of the original doubles (non-negatives get their sign
// bit flipped; negatives get every bit flipped).
template <>
struct RadixConfig<ga_double> {
  typedef unsigned long long int RadixType;
  static inline __device__ RadixType convert(ga_double v) {
    RadixType bits = __double_as_longlong(v);
    // negative double -> flip all bits; non-negative -> flip sign bit only
    RadixType flip = (bits & 0x8000000000000000ULL) ? 0xffffffffffffffffULL
                                                    : 0x8000000000000000ULL;
    return bits ^ flip;
  }
  static inline __device__ ga_double deconvert(RadixType v) {
    // invert convert(): a set top bit marks an originally non-negative value
    RadixType flip = (v & 0x8000000000000000ULL) ? 0x8000000000000000ULL
                                                 : 0xffffffffffffffffULL;
    return __longlong_as_double(v ^ flip);
  }
};
template <> template <>
struct RadixConfig<ga_ubyte> { struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
...@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> { ...@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
}; };
template <> template <>
struct RadixConfig<char> { struct RadixConfig<ga_byte> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(char v) { static inline __device__ RadixType convert(ga_byte v) {
return 128u + v; return 128u + v;
} }
static inline __device__ char deconvert(RadixType v) { static inline __device__ ga_byte deconvert(RadixType v) {
return v - 128; return v - 128;
} }
}; };
...@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> { ...@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) { static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2); assert(sizeof(ga_short) == 2);
return 32768u + v; return 32768u ^ v;
} }
static inline __device__ ga_short deconvert(RadixType v) { static inline __device__ ga_short deconvert(RadixType v) {
...@@ -75,45 +101,30 @@ struct RadixConfig<int> { ...@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) { static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4); assert(sizeof(int) == 4);
return 2147483648u + v; return (1u << 31) ^ v;
} }
static inline __device__ int deconvert(RadixType v) { static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u; return (1u << 31) ^ v;
} }
}; };
template <> template <>
struct RadixConfig<long> { struct RadixConfig<ga_long> {
typedef unsigned long long int RadixType; typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) { static inline __device__ RadixType convert(ga_long v) {
assert(sizeof(long) == 8); assert(sizeof(long) == 8);
return 9223372036854775808ull + v; return (1ull << 63) ^ v;
}
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
} }
static inline __device__ double deconvert(RadixType v) { static inline __device__ ga_long deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; return (1ull << 63) ^ v;
return __longlong_as_double(v ^ mask);
} }
}; };
#ifdef USE_HALF #ifdef USE_HALF
// TODO: make this work
template <> template <>
struct RadixConfig<half> { struct RadixConfig<half> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
...@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) { ...@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset)); return *((T*)((char*)ptr + offset));
} }
KERNEL void k_topk_dense( // read array element using raw(byte) offset
$dims template <typename T>
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM} static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
$dstv return __ldg(((T*)((char*)ptr + offset)));
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
} }
// works when length on axis is within max allowed threads in block (1024)
//
// Radix top-k over a single axis: one thread block per output slice, one
// thread per element along the axis.  The "$"-prefixed tokens ($dims, $dstv,
// $set_slice, ...) are substituted by the host (Python string.Template)
// before compilation; the commented parameter lists below show what each
// macro expands to.
KERNEL void k_topk_dense(
        $dims
        // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
        ga_ssize k,
        INPUT_TYPE* src,
        $src_strides
        // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
        ga_size size) {
  // per-warp histogram scratch: RADIX_SIZE counters for each of up to 32 warps
  LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
  // bins[b] = k2 - (count of candidates with digit >= b); the extra slot
  // bins[RADIX_SIZE] is a sentinel (set to 1 below) so that bins[digit+1]
  // is a valid read even when digit == RADIX_SIZE-1
  ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
  // is_topk:   this thread's element belongs to the top-k output
  // is_topkth: this thread's element is still a candidate for the k-th value
  bool is_topk=true, is_topkth=true;
  radix_t out_idx;
  const ga_ushort idx = LID_0;
  // k2: portion of k not yet resolved by higher radix digits
  // exceed: number of surplus elements tied with the top-k-th value
  ga_size LOCAL_MEM k2, exceed;
  const ga_ubyte warp_id = idx / GA_WARP_SIZE;
  const ga_ubyte lane_id = idx % GA_WARP_SIZE;
  const bool in_range = (idx < size);
  is_topk &= in_range;
  // 0. get the slice for thread block to work on
  // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
  ga_size gid = GID_0, gidx;
  $set_slice
  //for(int i=1; i<NDIM; i++) {
      //gidx = gid % dims_$${i};
      //gid /= dims_$${i};
      //dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
      //dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
      //src = ptr_add(src, gidx*src_strides_$${i});
  //}
  // get input and its radix friendly form (order-preserving unsigned bits)
  const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
  radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
  // resolve negative k: negative k selects the smallest |k| values, achieved
  // by bit-flipping the radix representation
  if (k<0) { x = ~x; k = -k; }
  if (idx==0) {
    k2 = k;
    bins[RADIX_SIZE] = 1;  // sentinel, see declaration of bins above
  }
  // 1. filter is_topk and is_topkth using radix select, scanning digits from
  //    most significant to least significant
  #pragma unroll
  for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
    int digit = (x>>i) & (RADIX_SIZE-1);
    // count within warp (one ballot/popcount per bin)
    #pragma unroll
    for (int bin=0; bin<RADIX_SIZE; ++bin) {
      bool incr_bin = (bin == digit) && is_topkth && in_range;
      ga_uint incr_bin_warp = __ballot(incr_bin);
      if (lane_id==0)
        smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
    }
    local_barrier();
    // sum counts across all warps
    // TODO: test in-block parallel sum?
    if (idx < RADIX_SIZE) {
      for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
        smem[idx] += smem[idx + w];
    }
    local_barrier();
    // bins = k - cumsum(smem[:RADIX_SIZE]), accumulated from the highest
    // digit downward; k2 is narrowed to the remainder inside the bin that
    // contains the k-th value
    if (idx == 0) {
      bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
      if (bins[RADIX_SIZE-1] > 0)
        k2 = bins[RADIX_SIZE-1];
      #pragma unroll
      for (int bin=RADIX_SIZE-1; bin; --bin) {
        bins[bin-1] = bins[bin] - smem[bin-1];
        if (bins[bin-1] > 0)
          k2 = bins[bin-1];
      }
    }
    local_barrier();
    // smem -> count
    // bins -> k2 - cumsum(count)
    if (is_topk && is_topkth) {
      ga_ssize icount = bins[digit];
      if (icount > 0) {
        // whole bin lies inside the top-k: settled, no longer a candidate
        // for the k-th value
        is_topkth = false;
      } else if (bins[digit+1] <= 0) {
        // strictly-greater digits already cover k: element is out
        is_topk = false;
        is_topkth = false;
      }
    }
  }
  // exceed = surplus of elements tied with the top-k-th value.
  // NOTE(review): exceed is only written if some bins[bin] <= 0 exists;
  // presumably that always holds when 0 < |k| <= size — confirm the host
  // guards k, otherwise the read below sees uninitialized shared memory.
  if (idx==0) {
    #pragma unroll
    for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
      if (bins[bin] <= 0) {
        exceed = -bins[bin];
        break;
      }
    }
  }
  local_barrier();
  // 2. find the index of output array, if exists
  if (exceed != 0) {
    // top_kth value may not be unique, so we need to
    // perform binary cumsum on is_topkth to drop exceeding top-kth values
    out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
    is_topk &= ((!is_topkth) || out_idx>=exceed);
  }
  // perform binary cumsum on is_topk to determine the indices to put result
  out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
  local_barrier();
  if (is_topk) {
#if WRITE_VALUE == 1
    ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
    ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
  }
}
// works when length on axis is larger than max allowed threads in block (1024)
//
// Radix top-k for long slices: one thread block per output slice; each of the
// LDIM_0 threads owns `inp_per_thread` consecutive elements along the axis
// (the per-thread chunk starts at src + idx*inp_per_thread*src_strides_0).
// Pass 1 radix-selects the bit pattern of the top-k-th value; passes 2 and 3
// stream the data again to write out qualifying values / indices.
// "$"-prefixed tokens are substituted by the host before compilation.
KERNEL void k_topk_dense_large(
        $dims
        // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
        ga_ssize k,
        INPUT_TYPE* src,
        $src_strides
        // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
        ga_size size, ga_ushort inp_per_thread) {
  // per-warp histogram scratch: RADIX_SIZE counters for each of up to 32 warps
  LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
  // known_bits / known_bits_mask: radix digits of the top-k-th value resolved
  // so far, and the mask of bit positions that are already resolved
  LOCAL_MEM radix_t known_bits, known_bits_mask;
  radix_t out_idx;
  // running base offset into the output for passes 2 and 3
  ga_size LOCAL_MEM write_base;
  INPUT_TYPE xval;
  radix_t x;
  ga_int i;
  bool in_range, is_topk;
  const ga_size idx = LID_0;
  // k2: portion of |k| not yet resolved by higher radix digits
  ga_size LOCAL_MEM k2;
  const ga_ushort warp_id = idx / GA_WARP_SIZE;
  const ga_ushort lane_id = idx % GA_WARP_SIZE;
  // 0. get the slice for thread block to work on
  // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
  ga_size gid = GID_0, gidx;
  $set_slice
  //for(int i=1; i<NDIM; i++) {
      //gidx = gid % dims_$${i};
      //gid /= dims_$${i};
      //dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
      //dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
      //src = ptr_add(src, gidx*src_strides_$${i});
  //}
  // advance src to this thread's private chunk
  src = ptr_add(src, idx*inp_per_thread*src_strides_0);
  // inv_bits: all-ones when k < 0 (select smallest |k| by flipping the
  // order-preserving radix representation), zero otherwise
  LOCAL_MEM radix_t inv_bits;
  if (idx==0) {
    known_bits = known_bits_mask = 0;
    k2 = abs(k);
    inv_bits = (k>=0) ? 0 : (~0);
    write_base = 0;
  }
  if (k<0) { k = -k; }
  local_barrier();
  // 1. find bits of top-k-th value using radix select, one digit (RADIX_BITS
  //    wide) per outer iteration, most significant first
  #pragma unroll
  for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
    /*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
    // reset this warp's counters
    if (lane_id == 0) {
      #pragma unroll
      for (int bin=0; bin<RADIX_SIZE; ++bin) {
        smem[bin + warp_id*RADIX_SIZE] = 0;
      }
    }
    local_barrier();
    // histogram the current digit over this thread's chunk, restricted to
    // elements still matching the digits resolved so far
    for (int j=0; j<inp_per_thread; ++j) {
      in_range = (idx*inp_per_thread+j) < size;
      xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
      x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
      ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
      // count within warp
      #pragma unroll
      for (int bin=0; bin<RADIX_SIZE; ++bin) {
        bool incr_bin = (
          (bin == digit) &&
          ((x&known_bits_mask) == known_bits) &&
          in_range);
        ga_uint incr_bin_warp = __ballot(incr_bin);
        if (lane_id==0)
          smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
      }
    }
    local_barrier();
    // sum counts across all warps
    // TODO: test in-block parallel sum?
    if (idx < RADIX_SIZE) {
      for(int w=RADIX_SIZE;
          w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
          w+=RADIX_SIZE)
        smem[idx] += smem[idx + w];
    }
    local_barrier();
    // update known bits: scan bins from the largest digit down until the
    // running count reaches k2 — that digit belongs to the top-k-th value
    if (idx==0) {
      #pragma unroll
      for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
        if (smem[bin] >= k2) {
          // cast to radix_t BEFORE shifting: for 64-bit input types i goes
          // up to bitsof(INPUT_TYPE)-RADIX_BITS (> 31), and shifting a plain
          // int by >= 32 is undefined behavior, silently corrupting the
          // resolved bits
          known_bits |= ((radix_t)bin << i);
          known_bits_mask |= ((radix_t)(RADIX_SIZE-1) << i);
          break;
        } else
          k2 -= smem[bin];
      }
    }
    local_barrier();
  }
  /*
  if (idx < RADIX_SIZE) {
    ptr_at(dstv, idx*dstv_strides_0) = known_bits;
    ptr_at(dstv, idx*dstv_strides_0) = smem[idx];
  }
  return;
  */
  // 2. write values strictly above the top-k-th (x > known_bits in the
  //    order-preserving radix domain); these are always in the result
  for (i=0; i<inp_per_thread; ++i) {
    in_range = (idx*inp_per_thread+i) < size;
    xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
    x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
    is_topk = (x > known_bits) && in_range;
    // inclusive cumsum -> 1-based output slot within this round
    out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
    if (is_topk) {
#if WRITE_VALUE == 1
      ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
      ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
    }
    local_barrier();
    // last thread holds the total count of this round
    if (idx == blockDim.x - 1)
      write_base += out_idx;
    local_barrier();
  }
  // 3. write values equal to the top-k-th, filling the remaining slots up
  //    to |k| (ties beyond |k| are dropped)
  for (i=0; i<inp_per_thread; ++i) {
    in_range = (idx*inp_per_thread+i) < size;
    xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
    x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
    is_topk = (x == known_bits) && in_range;
    out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
    is_topk = ((out_idx+write_base) <= abs(k)) && is_topk;
    if (is_topk) {
#if WRITE_VALUE == 1
      ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
      ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
    }
    local_barrier();
    if (idx == blockDim.x - 1)
      write_base += out_idx;
    local_barrier();
    // all |k| slots filled: stop streaming early
    if(write_base >= abs(k))
      break;
  }
}
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import os import os
from string import Template from string import Template
import pdb
import numpy as np
import theano import theano
from theano import Apply from theano import Apply
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
...@@ -20,7 +22,6 @@ except ImportError as e: ...@@ -20,7 +22,6 @@ except ImportError as e:
pass pass
# TODO add support when slice size is larger than max allowed block size (1024)
# TODO add runtime opt, if k==1, use max/min reduce # TODO add runtime opt, if k==1, use max/min reduce
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only # TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed # one result is needed
...@@ -33,12 +34,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -33,12 +34,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
''' '''
__props__ = TopKOp.__props__ __props__ = TopKOp.__props__
def __init__(self, axis=-1, return_indices=False, return_values=True): def __init__(self, axis=-1, return_values=True, return_indices=False, idx_dtype='int64'):
GpuKernelBase.__init__(self) GpuKernelBase.__init__(self)
TopKOp.__init__( TopKOp.__init__(
self, axis=axis, self, axis=axis,
return_values=return_values, return_values=return_values,
return_indices=return_indices) return_indices=return_indices,
idx_dtype=idx_dtype)
def c_headers(self): def c_headers(self):
return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h'] return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
...@@ -54,19 +56,23 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -54,19 +56,23 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def gpu_kernels(self, node, nodename): def gpu_kernels(self, node, nodename):
# load kernel source # load kernel source
device_type = node.inputs[0].type.context.kind device_type = node.inputs[0].type.context.kind
knames = ['k_topk_dense', 'k_topk_dense_large']
kernel_ext = {b'cuda':'.cu', b'opencl':'.cl'}[device_type] kernel_ext = {b'cuda':'.cu', b'opencl':'.cl'}[device_type]
try: common_ext = {b'cuda':'.cuh', b'opencl':'.h'}[device_type]
kernel_filename = 'topk_kernel%s' % kernel_ext kernel_src = {}
for kname in knames:
with open(os.path.join( with open(os.path.join(
os.path.dirname(__file__), kernel_filename os.path.dirname(__file__), kname + kernel_ext
), 'r') as f: ), 'r') as f:
kernel_src = f.read() kernel_src[kname] = f.read()
except FileNotFoundError:
raise RuntimeError( with open(os.path.join(
'Cannot find GPU kernel ' os.path.dirname(__file__), 'k_topk_common' + common_ext
'implementation for device "%s"' % device_type) ), 'r') as f:
common_src = f.read()
# prepare "$" macros # prepare "$" macros
if device_type == b'cuda':
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim)) dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim)) dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim))
...@@ -84,7 +90,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -84,7 +90,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
flags = Kernel.get_flags(node.inputs[0].dtype) flags = Kernel.get_flags(node.inputs[0].dtype)
subs = dict( subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype), inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(node.outputs[0].dtype), out_t=ga.dtype_to_ctype(self.idx_dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)), dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '', dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '', dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
...@@ -95,11 +101,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -95,11 +101,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
write_value=int(self.return_values), write_value=int(self.return_values),
write_index=int(self.return_indices), write_index=int(self.return_indices),
ndim=str(ndim)) ndim=str(ndim))
elif device_type == b'opencl':
raise NotImplementedError()
# substitute "$" macros in kernel code # compile kernels
kernel_src = Template(kernel_src).substitute(**subs) kernels = []
# compile kernel
param_types = [ga.SIZE] * (ndim - 1) # dims param_types = [ga.SIZE] * (ndim - 1) # dims
for _ in range(int(self.return_values) + int(self.return_indices)): for _ in range(int(self.return_values) + int(self.return_indices)):
param_types.append(ga.GpuArray) # dst* param_types.append(ga.GpuArray) # dst*
...@@ -108,31 +114,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -108,31 +114,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
param_types.append(ga.GpuArray) # src param_types.append(ga.GpuArray) # src
param_types.extend([ga.SSIZE] * ndim) # src_strides param_types.extend([ga.SSIZE] * ndim) # src_strides
param_types.append(ga.SIZE) # size param_types.append(ga.SIZE) # size
self.nargs = len(param_types) kernels.append(Kernel(
return [Kernel( code=Template(common_src + kernel_src['k_topk_dense']).substitute(**subs),
code=kernel_src,
name='k_topk_dense', name='k_topk_dense',
params=param_types, params=param_types,
flags=flags, flags=flags,
objvar='k_topk_dense_' + nodename objvar='k_topk_dense_' + nodename
)] ))
param_types.append(np.uint16) # inp_per_thread
kernels.append(Kernel(
code=Template(common_src + kernel_src['k_topk_dense_large']).substitute(**subs),
name='k_topk_dense_large',
params=param_types,
flags=flags,
objvar='k_topk_dense_large_' + nodename
))
return kernels
def c_code(self, node, nodename, inps, outs, sub): def c_code(self, node, nodename, inps, outs, sub):
if node.inputs[0].type.context.kind != b'cuda': if node.inputs[0].type.context.kind != b'cuda':
raise NotImplementedError('We only have CUDA implementation so far.') raise NotImplementedError(
'%s: We only have CUDA '
'implementation so far.' % self.__class__.__name__)
x, k = inps x, k = inps
inp_dtc = pygpu.dtypes.dtype_to_ctype(node.inputs[0].dtype).upper() inp_dtc = ga.dtype_to_typecode(node.inputs[0].dtype)
if not self.return_indices: if not self.return_indices:
yv, = outs yv, = outs
out_dtype_s = ''
out_dtc = ''
else: else:
if self.return_values: if self.return_values:
yv, yi = outs yv, yi = outs
else: else:
yi, = outs yi, = outs
out_dtype_s = node.outputs[0].dtype out_dtype_s = self.idx_dtype
out_dtc = pygpu.dtypes.dtype_to_ctype(out_dtype_s).upper() out_dtc = ga.dtype_to_typecode(out_dtype_s)
fail = sub['fail'] fail = sub['fail']
ctx = sub['params'] ctx = sub['params']
k_dtype = node.inputs[1].type.dtype_specs()[1] k_dtype = node.inputs[1].type.dtype_specs()[1]
...@@ -140,7 +154,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -140,7 +154,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
WARP_SIZE = 32 WARP_SIZE = 32
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
nargs = self.nargs
reordered_axes = list(range(ndim)) reordered_axes = list(range(ndim))
axis = self.axis % ndim axis = self.axis % ndim
del(reordered_axes[axis]) del(reordered_axes[axis])
...@@ -175,16 +188,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -175,16 +188,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
sstrides = ', '.join('(void*)(sstrides+%d)' % i for i in reordered_axes) sstrides = ', '.join('(void*)(sstrides+%d)' % i for i in reordered_axes)
code = ''' code = '''
{ {
const ssize_t k_ = ((%(k_dtype)s*)(PyArray_DATA(%(k)s)))[0];
const size_t *dims = PyGpuArray_DIMS(%(x)s); const size_t *dims = PyGpuArray_DIMS(%(x)s);
size_t odims[%(ndim)d]; size_t odims[%(ndim)d];
for (int i=0; i<%(ndim)d; i++) for (int i=0; i<%(ndim)d; i++)
odims[i] = dims[i]; odims[i] = dims[i];
odims[%(axis)d] = *((%(k_dtype)s*)(PyArray_DATA(%(k)s)));
if (odims[0] > %(MAX_TPB)d) {
odims[%(axis)d] = k_>=0 ? k_ : -k_;
if (0 == odims[%(axis)d]) {
PyErr_SetString( PyErr_SetString(
PyExc_ValueError, PyExc_ValueError,
"topk: slice size larger than %(MAX_TPB)d is not supported"); "topk: k must not be zero");
%(fail)s; } %(fail)s;
}
%(prep_output)s %(prep_output)s
// TODO better scheduling? // TODO better scheduling?
...@@ -192,32 +210,45 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -192,32 +210,45 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3; size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1; blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1; grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size
for(int i=0; i<%(ndim)d; ++i) { for(int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d) if (i!=%(axis)d)
grd[0] *= dims[i]; grd[0] *= dims[i];
else else
blk[0] = dims[i]; blk[0] = dims[i];
} }
// round up to multiples of warp size
blk[0] = ((blk[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d; blk[0] = ((blk[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
%(def_dvstrides)s; %(def_dvstrides)s;
%(def_distrides)s; %(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s); const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
// inputs per thread
unsigned short ipt = (dims[%(axis)d] + (%(MAX_TPB)d/2)-1) / (%(MAX_TPB)d/2);
void* args[] = { void* args[] = {
%(dims)s %(dims)s
%(params_dv)s %(params_dv)s
%(params_di)s %(params_di)s
(void*)(odims+%(axis)d), (void*)(&k_),
(void*)(%(x)s->ga.data), (void*)(%(x)s->ga.data),
%(sstrides)s, %(sstrides)s,
(void*)(dims+%(axis)d) (void*)(dims+%(axis)d),
(void*)(&ipt)
}; };
int err = GpuKernel_call( int err;
if (blk[0] > %(MAX_TPB)d) {
// CUDA_OUT_OF_RESOURCE if a max sized block is used
blk[0] = %(MAX_TPB)d / 2;
err = GpuKernel_call(
&k_topk_dense_large_%(nodename)s, 3,
grd, blk, 0,
args);
} else {
err = GpuKernel_call(
&k_topk_dense_%(nodename)s, 3, &k_topk_dense_%(nodename)s, 3,
grd, blk, 0, grd, blk, 0,
args); args);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString( PyErr_SetString(
PyExc_RuntimeError, PyExc_RuntimeError,
...@@ -228,37 +259,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -228,37 +259,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
''' '''
return code % locals() return code % locals()
def make_node(self, inp, k, idx_dtype='int64'): def make_node(self, inp, k):
ctx_name = infer_context_name(inp) ctx_name = infer_context_name(inp)
inp = as_gpuarray_variable(inp, ctx_name) inp = as_gpuarray_variable(inp, ctx_name)
k = as_tensor_variable(k) k = as_tensor_variable(k)
bcast = inp.type.broadcastable bcast = inp.type.broadcastable
outs = [] outs = []
if self.return_values:
outs.append(inp.type())
if self.return_indices: if self.return_indices:
outs.append(GpuArrayType( outs.append(GpuArrayType(
dtype=idx_dtype, dtype=self.idx_dtype,
broadcastable=bcast, broadcastable=bcast,
context_name=ctx_name)()) context_name=ctx_name)())
if self.return_values:
outs.append(inp.type())
return Apply(self, [inp, k], outs) return Apply(self, [inp, k], outs)
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return node.inputs[0].type.context
# def get_op_params(self):
# return [('AXIS', self.axis)]
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([TopKOp]) @op_lifter([TopKOp])
@register_opt2([TopKOp], 'fast_compile') @register_opt2([TopKOp], 'fast_compile')
def local_gpua_topkop(op, ctx_name, inputs, outputs): def local_gpua_topkop(op, ctx_name, inputs, outputs):
if isinstance(op, GpuTopKOp):
return False
axis = op.axis axis = op.axis
rv = op.return_values rv = op.return_values
ri = op.return_indices ri = op.return_indices
x, k = inputs x, k = inputs
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
y = outputs[-1] rets = GpuTopKOp(
return GpuTopKOp( axis=axis, return_values=rv, return_indices=ri, idx_dtype=op.idx_dtype)(x, k)
axis=axis, return_values=rv, return_indices=ri)(x, k, idx_dtype=y.dtype) return rets
...@@ -342,20 +342,21 @@ class TopKOp(theano.Op): ...@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code # TODO c_code
__props__ = ('axis', 'return_values', 'return_indices') __props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
def __init__(self, axis=-1, return_indices=False, return_values=True): def __init__(self, axis=-1, return_indices=False, return_values=True, idx_dtype='int64'):
assert isinstance(axis, int) assert isinstance(axis, int)
assert return_indices or return_values assert return_indices or return_values
self.axis = axis self.axis = axis
self.return_indices = return_indices self.return_indices = return_indices
self.return_values = return_values self.return_values = return_values
self.idx_dtype = idx_dtype
def __str__(self): def __str__(self):
return '%(op)s{axis=%(axis)d}' % dict( return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, axis=self.axis) op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k, idx_dtype='int64'): def make_node(self, inp, k):
# numpy always uses float64 as output dtype for arg*() routines # numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu # however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
...@@ -366,7 +367,7 @@ class TopKOp(theano.Op): ...@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs.append(inp.type()) outs.append(inp.type())
if self.return_indices: if self.return_indices:
outs.append( outs.append(
theano.tensor.TensorType(dtype=idx_dtype, broadcastable=bcast)()) theano.tensor.TensorType(dtype=self.idx_dtype, broadcastable=bcast)())
return theano.Apply(self, [inp, k], outs) return theano.Apply(self, [inp, k], outs)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
...@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'): ...@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return TopKOp(axis=axis, return_indices=True, return_values=False)(x, k, idx_dtype=idx_dtype) return TopKOp(axis=axis, return_indices=True, return_values=False, idx_dtype=idx_dtype)(x, k)
def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'): def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
''' """
Returns the results of both topk() and argtopk() in one Op. Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details. See the respective documentation for details.
''' """
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return TopKOp(axis=axis, return_indices=True)(x, k, idx_dtype=idx_dtype) return TopKOp(axis=axis, return_indices=True, idx_dtype=idx_dtype)(x, k)
...@@ -21,14 +21,12 @@ _int_dtypes = ( ...@@ -21,14 +21,12 @@ _int_dtypes = (
'int8', 'int16', 'int32', 'int64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64') 'uint8', 'uint16', 'uint32', 'uint64')
def gen_unique_vector(size, dtype): def gen_unique_vector(size, dtype):
# generate a randomized vector with unique elements # generate a randomized vector with unique elements
retval = np.arange(size*3) + np.random.uniform(-1., 1.) retval = np.arange(size*3) + np.random.uniform(-1., 1.)
return (retval[np.random.permutation(size)] - size*1.5).astype(dtype) return (retval[np.random.permutation(size)] - size*1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase): class Test_sort(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -236,7 +234,6 @@ def test_argsort_grad(): ...@@ -236,7 +234,6 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX) data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data]) utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_TopK(unittest.TestCase): class Test_TopK(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论