Commit 95f6eda6 authored by Adam Becker

mixed changes

- add multidim support for top_k - use unified TopKOp, can implement topk, argtopk, and both
Parent ece8b25b
from __future__ import absolute_import, print_function, division
import os
from string import Template
import theano
from theano import Apply
from theano.tensor import as_tensor_variable
from theano.tensor.sort import ArgTopKOp
from theano.tensor.sort import TopKOp
from .basic_ops import (GpuKernelBase, Kernel, infer_context_name,
as_gpuarray_variable, gpu_contiguous)
from .opt import register_opt, op_lifter, register_opt2
from .type import GpuArrayType
try:
import pygpu
import pygpu.gpuarray as ga
except ImportError as e:
# To make sure theano is importable
pass
class GpuSortOp(object):
    # Placeholder for a GPU implementation of sort(); not implemented yet.
    # TODO
    pass
class GpuArgSortOp(object):
    # Placeholder for a GPU implementation of argsort(); not implemented yet.
    # TODO
    pass
# TODO add support is slice size is larger than max allowed block size (1024)
# TODO add runtime opt, if k==1, use max/min reduce
# TODO sort / argsort
class GpuArgTopKOp(ArgTopKOp, GpuKernelBase):
class GpuTopKOp(GpuKernelBase, TopKOp):
'''
implement argtopk() on gpu
Implements TopKOp() on gpu
'''
__props__ = ArgTopKOp.__props__
def __init__(self, axis=-1):
ArgTopKOp.__init__(self, axis=axis)
__props__ = TopKOp.__props__
def __init__(self, axis=-1, return_indices=False, return_values=True):
GpuKernelBase.__init__(self)
TopKOp.__init__(
self, axis=axis,
return_values=return_values,
return_indices=return_indices)
    def c_headers(self):
        """Headers required by the generated C code (gpuarray API/helpers, numpy compat)."""
        return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
......@@ -40,67 +43,180 @@ class GpuArgTopKOp(ArgTopKOp, GpuKernelBase):
    def c_header_dirs(self):
        """Search paths for the headers above: this module's directory and pygpu's include dir."""
        return [os.path.dirname(__file__), pygpu.get_include()]
'''
def c_code_cache_version(self):
return (1,)
'''
def gpu_kernels(self, node, nodename):
device_type = str(node.inputs[0].type.context.kind)
kernel_ext = dict(cuda='.cu', opencl='.cl')[device_type]
flags = Kernel.get_flags(node.inputs[0].dtype)
# load kernel source
device_type = node.inputs[0].type.context.kind
kernel_ext = {b'cuda':'.cu', b'opencl':'.cl'}[device_type]
try:
kernel_filename = 'topk_kernel%s' % kernel_ext
with open(os.path.join(
os.path.dirname(__file__), kernel_filename
)) as f:
), 'r') as f:
kernel_src = f.read()
except FileNotFoundError:
raise RuntimeError(
'Cannot find GPU kernel '
'implementation for device "%s"' % device_type)
return [Kernel(
kernel_src,
params='TODO_params',
name='topk_kernel',
flags=flags,
)]
# prepare "$" macros
ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ga_ssize src_strides_%d, ' % i for i in range(ndim))
set_slice_code = '''
gidx = gid %% dims_%(i)d;
gid /= dims_%(i)d;
{dstv};
{dsti};
src = ptr_add(src, gidx*src_strides_%(i)d);\n'''.format(
dstv='dstv = ptr_add(dstv, gidx*dstv_strides_%(i)d)' if self.return_values else '',
dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '')
set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim))
flags = Kernel.get_flags(node.inputs[0].dtype)
dst = ''
if self.return_values:
dst += 'INPUT_TYPE *dstv, '
if self.return_values:
dst += 'INDEX_TYPE *dsti, '
write_value = 'ptr_at(dstv, out_idx * dstv_strides_0) = xval' if self.return_values else ''
write_index = 'ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx' if self.return_indices else ''
subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(node.outputs[0].dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
dstv_strides=dstv_strides_code,
dsti_strides=dsti_strides_code,
src_strides=src_strides_code,
set_slice=set_slice_code,
write_value=write_value,
write_index=write_index,
ndim=str(ndim))
# substitute "$" macros in kernel code
kernel_src = Template(kernel_src).substitute(**subs)
# compile kernel
param_types = [ga.SIZE] * (ndim - 1) # dims
for _ in range(int(self.return_values) + int(self.return_indices)):
param_types.append(ga.GpuArray) # dst*
param_types.extend([ga.SSIZE] * ndim) # dst*_strides
param_types.append(ga.SIZE) # k
param_types.append(ga.GpuArray) # src
param_types.extend([ga.SSIZE] * ndim) # src_strides
param_types.append(ga.SIZE) # size
return [Kernel(
code=kernel_src,
name='k_topk_dense',
params=param_types,
flags=flags,
objvar='k_topk_dense_' + nodename
)]
def c_code(self, node, nodename, inps, outs, sub):
if node.inputs[0].type.context.kind != b'cuda':
raise NotImplementedError('We only have CUDA implementation so far.')
x, k = inps
y, = outs
inp_dtc = pygpu.dtypes.dtype_to_ctype(node.inputs[0].dtype).upper()
if not self.return_indices:
yv, = outs
out_dtype_s = ''
out_dtc = ''
else:
if self.return_values:
yv, yi = outs
else:
yi, = outs
out_dtype_s = node.outputs[0].dtype
out_dtc = pygpu.dtypes.dtype_to_ctype(out_dtype_s).upper()
fail = sub['fail']
ctx = sub['params']
out_dtype = pygpu.dtypes.dtype_to_ctype(self.out_dtype).upper()
k_dtype = node.inputs[1].type.dtype_specs()[1]
MAX_TPB = 1024 # max thread per block
WARP_SIZE = 32
ndim = node.inputs[0].ndim
reordered_axes = list(range(ndim))
axis = self.axis % ndim
del(reordered_axes[axis])
reordered_axes = [axis] + reordered_axes
dims = ', '.join('(void*)(dims+%d)' % i for i in reordered_axes[1:])
prep_output = ''
if self.return_values:
def_dvstrides = 'const ssize_t *dvstrides = PyGpuArray_STRIDES(%s)' % yv
params_dv = '(void*)(%s->ga.data),\n' % yv
params_dv += ''.join('(void*)(dvstrides+%d), ' % i for i in reordered_axes)
prep_output += '''
if (0 != theano_prep_output(
&%(yv)s, %(ndim)d, odims,
%(inp_dtc)s, GA_C_ORDER, %(ctx)s)) {
%(fail)s;
}\n''' % locals()
else:
def_dvstrides = params_dv = ''
if self.return_indices:
def_distrides = 'const ssize_t *distrides = PyGpuArray_STRIDES(%s)' % yi
params_di = '(void*)(%s->ga.data),\n' % yi
params_di += ''.join('(void*)(distrides+%d), ' % i for i in reordered_axes)
prep_output += '''
if (0 != theano_prep_output(
&%(yi)s, %(ndim)d, odims,
%(out_dtc)s, GA_C_ORDER, %(ctx)s)) {
%(fail)s;
}\n''' % locals()
else:
def_distrides = params_di = ''
sstrides = ', '.join('(void*)(sstrides+%d)' % i for i in reordered_axes)
code = '''
{
// prepare output
const size_t *dims = PyGpuArray_DIMS(%(x)s);
const size_t *odims[1] = {*((%(out_dtype)s)PyArray_DATA(%(k)s))};
size_t odims[%(ndim)d];
for (int i=0; i<%(ndim)d; i++) {
odims[i] = dims[i];
}
odims[%(axis)d] = *((%(k_dtype)s*)(PyArray_DATA(%(k)s)));
if (odims[0] > %(MAX_TPB)d) {
PyErr_SetString(
PyExc_ValueError,
"topk: slice size larger than %(MAX_TPB)d is not supported");
%(fail)s; }
if (0 != theano_prep_output(
&%(y)s, 1, odims,
%(out_dtype)s, GA_C_ORDER, %(ctx)s)) {
%(fail)s;
}
size_t blk[6] = ;
size_t grd = blk+3;
%(prep_output)s
// TODO better scheduling?
size_t blk[6];
size_t *grd = blk+3;
blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size
blk[0] = (dims[0] + (%(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
blk[0] = ((dims[0] + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;
for(int i=0; i<%(ndim)d; ++i) {
if (i!=%(axis)d)
grd[0] *= dims[i];
}
%(def_dvstrides)s;
%(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
void* args[] = {
((void*)(%(y)s->ga.data)),
((void*)(%(x)s->ga.data)),
(void*)dims, (void*)odims
%(dims)s
%(params_dv)s
%(params_di)s
(void*)(odims+%(axis)d),
(void*)(%(x)s->ga.data),
%(sstrides)s,
(void*)(dims+%(axis)d)
};
int err = GpuKernel_call(
&topk_kernel, 3,
&k_topk_dense_%(nodename)s, 3,
grd, blk,
blk[0] * gpuarray_get_elsize(%(x)s->ga.typecode),
args);
......@@ -114,21 +230,37 @@ class GpuArgTopKOp(ArgTopKOp, GpuKernelBase):
'''
return code % locals()
def make_node(self, inp, k, out_dtype='int64'):
def make_node(self, inp, k, idx_dtype='int64'):
ctx_name = infer_context_name(inp)
inp = as_gpuarray_variable(inp, ctx_name)
k = as_tensor_variable(k)
bcast = inp.type.broadcastable
return Apply(
self, [inp, k],
[GpuArrayType(
dtype=out_dtype,
outs = []
if self.return_indices:
outs.append(GpuArrayType(
dtype=idx_dtype,
broadcastable=bcast,
context_name=ctx_name)()])
context_name=ctx_name)())
if self.return_values:
outs.append(inp.type())
return Apply(self, [inp, k], outs)
    def get_params(self, node):
        """Return the gpuarray context of the first input.

        This becomes the ``params`` substitution (``sub['params']``) used as
        the context argument in the generated C code.
        """
        return node.inputs[0].type.context
def get_op_params(self):
return [('AXIS', self.axis)]
# def get_op_params(self):
# return [('AXIS', self.axis)]
@register_opt('fast_compile')
@op_lifter([TopKOp])
@register_opt2([TopKOp], 'fast_compile')
def local_gpua_topkop(op, ctx_name, inputs, outputs):
    """Graph optimization: replace a CPU ``TopKOp`` with ``GpuTopKOp``.

    Moves the input onto the GPU context and rebuilds the op with the same
    axis / return_values / return_indices configuration.
    """
    axis = op.axis
    rv = op.return_values
    ri = op.return_indices
    x, k = inputs
    x = as_gpuarray_variable(x, ctx_name)
    # NOTE(review): outputs[-1] is presumably the indices output (it is the
    # last output when return_indices is True on the CPU op) and its dtype is
    # reused as idx_dtype — confirm against TopKOp.make_node output ordering.
    y = outputs[-1]
    return GpuTopKOp(
        axis=axis, return_values=rv, return_indices=ri)(x, k, idx_dtype=y.dtype)
......@@ -113,6 +113,7 @@ struct RadixConfig<double> {
}
};
#ifdef USE_HALF
template <>
struct RadixConfig<half> {
typedef unsigned int RadixType;
......@@ -138,23 +139,29 @@ struct RadixConfig<half> {
#endif
}
};
#endif
// $$inp_t should be replaced in c_code
// we cannot use templated __global__ because gpuarray API does not support it yet
#define NDIM $ndim
#define INPUT_TYPE $inp_t
#define INDEX_TYPE $out_t
#define bitsof(T) (sizeof(T)*8)
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<T>::RadixType
#define radix_t RadixConfig<INPUT_TYPE>::RadixType
#if RADIX_SIZE > 32
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
template <typename T>
inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
static inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads whose id is *no greater than* the current thread
// cumsum within warp
unsigned int warp_bits = __ballot(in);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((2<<lane_id)-1) & warp_bits);
if (lane_id == 0)
......@@ -175,20 +182,21 @@ inline __device__ T binary_cumsum(int idx, int warp_id, int lane_id, T* smem, bo
__syncthreads();
// load the carry from the preceding warp
if (warp >= 1) {
warp_sum = warp_sum+smem[warp - 1];
if (warp_id >= 1) {
warp_sum = warp_sum+smem[warp_id - 1];
}
return warp_sum;
}
template <typename T>
inline __device__ T binary_cumsum_exclusive(
static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, int lane_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread
// cumsum within warp
unsigned int warp_bits = __ballot(in);
unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(((1<<lane_id)-1) & warp_bits);
if (lane_id == 0)
......@@ -209,35 +217,77 @@ inline __device__ T binary_cumsum_exclusive(
__syncthreads();
// load the carry from the preceding warp
if (warp >= 1) {
warp_sum = warp_sum+smem[warp - 1];
}
if (warp_id >= 1)
warp_sum += smem[warp_id - 1];
return warp_sum;
}
// Apply a raw (byte) offset to a typed pointer; used to walk arrays by
// gpuarray strides, which are expressed in bytes rather than elements.
template <typename T>
static __device__ inline T* ptr_add(T *ptr, ga_ssize offset) {
    return (T*)((char*)ptr + offset);
}
// get array element using raw(byte) offset
template <typename T>
void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
extern radix_t smem[];
ssize_t bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
/*
extern __shared__ radix_t smem[];
ga_ssize __shared__ bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk = true;
bool is_topkth = true; // exactly k-th largest
radix_t out_idx;
const size_t idx = threadIdx.x;
size_t __shared__ k2, exceed;
const ga_uint warp_id = idx / 32;
const ga_uint lane_id = idx % 32;
radix_t *wmem = (radix_t*)(smem) + warp_id * 32;
const bool in_range = (idx < size);
is_topk &= in_range;
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 0. get the slice for thread block to work on
size_t gid = blockIdx.x, gidx;
$set_slice
//for(int i=0; i<NDIM; i++) {
//gidx = gid % dims_$${i};
//gid /= dims_$${i};
//dsti = ptr_add(dsti, gidx*dsti_strides_$${i+1};
//dstv = ptr_add(dstv, gidx*dstv_strides_$${i+1};
//src = ptr_add(src, gidx*src_strides_$${i+1});
//}
// 1. filter is_topk and is_topkth using radix select
size_t idx = threadIdx.x;
size_t k2 = k, exceed;
int warp_id = idx / 32;
int lane_id = idx % 32;
radix_t wmem = smem + warp_id * 32;
bool in_range = (idx < size);
RadixConfig<T>::RadixType x = in_range ? RadixConfig<T>::convert(src[idx]) : 0;
// 1. find the kth largest value using radix select
// 1.1 for each radix mask, count
smem[threadIdx.x] = 0;
#pragma unroll
for (int i=bitsof(T)-RADIX_BITS; i; i-=RADIX_BITS) {
radix_t mask = (RADIX_SIZE-1)<<i;
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
smem[idx] = 0;
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
......@@ -245,43 +295,34 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
unsigned int incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
wmem[bin] += __popc(bin_warp);
wmem[bin] += __popc(incr_bin_warp);
}
__syncthreads();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx==0) {
for(int w=1; w<blockDim.x/32; ++w) {
#pragma unroll
for(int bin=0; bin<RADIX_SIZE; ++bin) {
smem[bin] += wmem[bin];
}
}
if (idx < RADIX_SIZE) {
for(int w=32; w<blockDim.x; w+=32)
smem[idx] += smem[idx + w];
}
__syncthreads();
// broadcast sum result
if (idx < RADIX_SIZE)
smem[idx] = bins[idx];
__syncthreads();
// calculate k minus cumsum(count)
exceed = -k; // how many the number of is_topk exceeds k
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
bins[0] = k2 - smem[0];
if (bins[0] > 0)
k2 = bins[0];
else if (bins[0] < 0)
exceed = max(exceed, bins[0]);
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=1; bin<RADIX_SIZE; ++bin) {
bins[bin] = bins[bin-1] - smem[bin];
if (bins[bin] > 0)
k2 = bins[bin];
else if (bins[bin] < 0)
exceed = max(exceed, bins[bin]);
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
__syncthreads();
......@@ -290,7 +331,7 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ssize_t icount = bins[digit];
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
......@@ -305,17 +346,23 @@ void __global__ topk_1d_contig_kernel(T* dst, T* src, size_t size, size_t k) {
}
// 2. find the index of output array, if exists
//
// top_kth value may not be unique, so we need to
// count how many is needed
// perform binary cumsum on is_topkth to drop exceeding top-kth values
radix_t topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
if (topkth_idx >= exceed)
is_topk = false;
// perform binary cumsum on is_topk to determine idx to put result
topkth_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
if (is_topk)
dst[topkth_idx] = idx;
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
__syncthreads();
if (is_topk) {
$write_value;
// ptr_at(dstv, out_idx * dstv_strides_0) = xval;
$write_index;
// ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
}
*/
}
......@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian, consider_constant
from theano.tensor.sort import sort, argsort, argtopk
from theano.tensor.sort import sort, argsort, topk, argtopk, topk_and_argtopk
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, fill_diagonal_offset,
cumsum, cumprod)
......
......@@ -221,117 +221,217 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
if hasattr(np, 'argpartition'):
def _argtopk_py_impl(x, k, axis, out_dtype):
# numpy >= 1.8 implementation
if k == 1:
return np.expand_dims(
np.argmax(x, axis=axis).astype(out_dtype), axis)
elif k == -1:
return np.expand_dims(
np.argmin(x, axis=axis).astype(out_dtype), axis)
# numpy >= 1.8 implementation
def _topk_py_impl(op, x, k, axis, idx_dtype):
ndim = x.ndim
asize = x.shape[axis]
if asize == abs(k):
z = np.arange(abs(k), dtype=out_dtype)
l = axis % ndim
r = ndim - l
z = z.reshape((1,) * l + (k,) + (1,) * (r - 1))
reps = list(x.shape)
reps[axis] = 1
return np.tile(z, reps)
print('used axis %d' % axis)
z = np.argpartition(x, -k, axis=axis)
idx = (slice(None),) * (axis % ndim)
if k > 0:
idx += (slice(-k, None),)
elif k < 0:
idx += (slice(-k),)
else:
raise ValueError('k cannot be zero')
return z[idx].astype(out_dtype)
else:
def _argtopk_py_impl(x, k, axis, out_dtype):
if k == 1:
return np.argmax(x, axis=axis).astype(out_dtype)
elif k == -1:
return np.argmin(x, axis=axis).astype(out_dtype)
if abs(k) == 1:
i = (k + 1) // 2
fn_max = [np.min, np.max][i]
fn_argmax = [np.argmin, np.argmax][i]
if not op.return_indices:
return np.expand_dims(fn_max(x, axis=axis), axis)
elif op.return_values:
zi = np.expand_dims(
fn_argmax(x, axis=axis).astype(idx_dtype), axis)
idx2 = tuple(np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi for i, s in enumerate(x.shape))
zv = x[idx2]
return zv, zi.astype(idx_dtype)
else:
zi = np.expand_dims(
fn_argmax(x, axis=axis).astype(idx_dtype), axis)
return zi.astype(idx_dtype)
ndim = x.ndim
asize = x.shape[axis]
if asize == abs(k):
z = np.arange(abs(k), dtype=out_dtype)
l = axis % ndim
r = ndim - l
z = z.reshape((1,) * l + (k,) + (1,) * r)
reps = list(x.shape)
reps[axis] = 1
return np.tile(z, reps)
# numpy implementation for older version
z = np.argsort(x, axis=axis)
idx = (slice(None),) * (axis - 1)
if not op.return_indices:
return x.copy()
else:
l = axis
r = ndim - l
reps = list(x.shape)
reps[axis] = 1
zi = np.arange(abs(k), dtype=idx_dtype)
zi = zi.reshape((1,) * l + (k,) + (1,) * (r - 1))
zi = np.tile(zi, reps)
if op.return_values:
return x.copy(), zi
else:
return zi
idx = [slice(None)] * ndim
if k > 0:
idx += (slice(-k, None),)
idx[axis] = slice(-k, None)
elif k < 0:
idx += (slice(-k),)
idx[axis] = slice(-k)
else:
raise ValueError('k cannot be zero')
return z[idx].astype(out_dtype)
if not op.return_indices:
zv = np.partition(x, -k, axis=axis)[idx]
return zv
elif op.return_values:
zi = np.argpartition(x, -k, axis=axis)[idx]
idx2 = tuple(np.arange(s).reshape((s,)+(1,)*(ndim-i-1)) if i != axis else zi for i, s in enumerate(x.shape))
zv = x[idx2]
return zv, zi.astype(idx_dtype)
else:
zi = np.argpartition(x, -k, axis=axis)[idx]
return zi
else:
def _topk_py_impl(op, x, k, axis, idx_dtype):
# TODO better compatibility?
raise NotImplementedError('TopKOp: need numpy.argpartition() method (numpy >= 1.8)')
class ArgTopKOp(theano.Op):
class TopKOp(theano.Op):
"""
See help(theano.argtopk)
Operations related to finding k-largest elements.
The outputs of this Op depends on ``returns_values`` and ``return_indices``,
if both ``True``, will return two outputs, corresponding to k-largest values
and indices. If only one is ``True``, this Op shall have only one output. Can't
be both ``False``.
Parameters
----------
axis: integer
The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where
``ndim`` is the dimensionality of input tensor.
return_values: bool
Defaults to ``True``
If ``True``, one output of the Op will return k-largest array values.
return_indices: bool
Defaults to ``False``
If ``True``, one output of the Op will return the indices on the given axis.
Notes
-----
- ``return_values`` and ``return_indices`` cannot be both ``False``
See Also
--------
topk
argtopk
argtopk_and_topk
"""
__props__ = ('axis',)
# TODO more params
'''
sorted: bool
Defaults to ``False``
If True, the result array would be incremental-sorted. Mutually exclusive with ``sparse``
sparse: bool
Defaults to ``False``
if ``True``, the output array will always have the same shape as input.
The non-top-k values will be replaced by zero.
def __init__(self, axis=-1):
only_top_kth: bool
Defaults to ``False``
If ``True``, will only find the exact top k-th element. The Op behaves
like a reduction.
'''
# TODO c_code
__props__ = ('axis', 'return_values', 'return_indices')
def __init__(self, axis=-1, return_indices=False, return_values=True):
assert isinstance(axis, int)
assert return_indices or return_values
self.axis = axis
self.return_indices = return_indices
self.return_values = return_values
def __str__(self):
return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k, out_dtype='int64'):
def make_node(self, inp, k, idx_dtype='int64'):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp)
k = theano.tensor.as_tensor_variable(k)
bcast = inp.type.broadcastable
return theano.Apply(self, [inp, k], [
theano.tensor.TensorType(
dtype=out_dtype,
broadcastable=bcast)()])
outs = []
if self.return_values:
outs.append(inp.type())
if self.return_indices:
outs.append(
theano.tensor.TensorType(dtype=idx_dtype, broadcastable=bcast)())
return theano.Apply(self, [inp, k], outs)
def perform(self, node, inputs, output_storage):
x, k = inputs
pz = output_storage[0]
print("Op's axis: %d" % self.axis)
pz[0] = _argtopk_py_impl(x, k, self.axis, node.outputs[0].dtype)
ndim = x.ndim
axis = self.axis
assert -ndim <= axis < ndim
axis %= ndim
if not self.return_indices:
pzv = output_storage[0]
pzv[0] = _topk_py_impl(self, x, k, axis, None)
elif self.return_values:
pzv = output_storage[0]
pzi = output_storage[1]
pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
else:
pzi = output_storage[0]
pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)
def infer_shape(self, node, inp_shapes):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
_check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0])
if not isinstance(self.axis, int):
raise TypeError(
'axis parameter must be integer, got "%s"' % type(self.axis))
'"axis" parameter must be integer, got "%s"' % type(self.axis))
ndim = node.inputs[0].ndim
if ndim == 0:
raise ValueError('cannot use 0d tensor')
raise ValueError('Cannot take 0d tensor as input')
if not -ndim <= self.axis < ndim:
raise IndexError(
'axis parameter out of range,'
'"axis" parameter out of range,'
' expected integer within [%d, %d]' % (-ndim, ndim - 1))
shp[self.axis] = np.abs(node.inputs[1])
return [tuple(shp)]
shp = tuple(shp)
return [shp for i in [self.return_values, self.return_indices] if i]
def topk(x, k, axis=-1):
"""
Returns the k-largest elements along an axis.
Parameters
----------
x: tensor instance
k: integer constant/variable
Must not be 0. If negative, gives k-smallest elements instead.
axis: integer or ``None``
Upon which axis shall the operation be performed on. If ``None``,
works on flattened array.
Notes
-----
- The returned values may not be sorted.
def argtopk(x, k, axis=-1, out_dtype='int64'):
"""
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return TopKOp(axis=axis)(x, k)
def argtopk(x, k, axis=-1, idx_dtype='int64'):
"""
Returns the indices of k-largest elements along an axis.
......@@ -341,21 +441,35 @@ def argtopk(x, k, axis=-1, out_dtype='int64'):
x: tensor instance
k: integer constant/variable
Must not be 0. If negative, gives k-least elements instead.
Must not be 0. If negative, gives k-smallest elements instead.
axis: integer or ``None``
Upon which axis shall the operation be performed on. If ``None``,
works on flattened array.
out_dtype: string
Specify output dtype, defaults to ``int64``, must be integer type
idx_dtype: string
Specify output dtype, defaults to ``int64``, must be integer type.
Notes
-----
- The corresponding value of returned indices may not be sorted themselves
- The corresponding values of returned indices may not be sorted.
"""
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return ArgTopKOp(axis=axis)(x, k, out_dtype=out_dtype)
return TopKOp(axis=axis, return_indices=True, return_values=False)(x, k, idx_dtype=idx_dtype)
def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
    '''
    Returns the results of both topk() and argtopk() in one Op.

    See the respective documentation for details.

    Returns
    -------
    tuple
        ``(values, indices)`` of the k-largest elements along ``axis``.
    '''
    # axis=None means operate on the flattened tensor
    if axis is None:
        x = theano.tensor.flatten(x)
        axis = -1
    return TopKOp(axis=axis, return_indices=True)(x, k, idx_dtype=idx_dtype)
......@@ -11,7 +11,7 @@ from theano import tensor
from theano.tensor.sort import sort, SortOp
from theano.tensor.sort import argsort, ArgSortOp
from theano.tensor.sort import argtopk, ArgTopKOp
from theano.tensor.sort import topk, argtopk, topk_and_argtopk, TopKOp
_dtypes = (
'float32', 'float64',
......@@ -24,10 +24,11 @@ _int_dtypes = (
def gen_unique_vector(size, dtype):
# generate a randomized vector with unique elements
retval = np.cumsum(np.random.uniform(1.01, 3.01, size))
return (retval[np.random.permutation(size)] - size).astype(dtype)
retval = np.arange(size*3) + np.random.uniform(-1., 1.)
return (retval[np.random.permutation(size)] - size*1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase):
def setUp(self):
......@@ -235,21 +236,67 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_topk(unittest.TestCase):
class Test_TopK(unittest.TestCase):
def setUp(self):
pass
@utt.parameterized.expand(product(
_dtypes, _int_dtypes, [-1, 0, None]))
def test_sanity(self, dtype, out_dtype, axis):
def test_argtopk_sanity(self, dtype, idx_dtype, axis):
x = tensor.vector(name='x', dtype=dtype)
fn = theano.function([x], argtopk(x, 1, axis=axis, out_dtype=out_dtype))
fn = theano.function([x], argtopk(x, 1, axis=axis, idx_dtype=idx_dtype))
xval = np.asarray([1]).astype(dtype)
yval = fn(xval)
assert yval == np.asarray([0], dtype=out_dtype)
assert yval == np.asarray([0], dtype=idx_dtype)
    @utt.parameterized.expand(product(
        _dtypes, [-1, 0, None]))
    def test_topk_sanity(self, dtype, axis):
        # smoke test: topk of a single-element vector returns that element,
        # for every dtype and for axis in {-1, 0, None (flattened)}
        x = tensor.vector(name='x', dtype=dtype)
        fn = theano.function([x], topk(x, 1, axis=axis))
        xval = np.asarray([1]).astype(dtype)
        yval = fn(xval)
        assert yval == xval
    @utt.parameterized.expand(product(
        _dtypes, _int_dtypes, [-1, 0, None]))
    def test_combined_sanity(self, dtype, idx_dtype, axis):
        # smoke test: topk_and_argtopk on a single-element vector returns
        # that element as the value and 0 as the index
        x = tensor.vector(name='x', dtype=dtype)
        yv, yi = topk_and_argtopk(x, 1, axis=axis, idx_dtype=idx_dtype)
        fn = theano.function([x], [yv, yi])
        xval = np.asarray([1]).astype(dtype)
        yvval, yival = fn(xval)
        assert yival == np.asarray([0], dtype=idx_dtype)
        assert np.allclose(xval, yvval)
@utt.parameterized.expand(chain(
product(
(16, 61, 257),
(1, -1, 10, -10, 'n//2', 'n-1', '-n', '1-n'),
('float64', 'int16', 'int8')),
((2049, 1337, 'float64'),)))
def test_topk_1d(self, size, k, dtype):
if isinstance(k, str):
k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype)
y = topk(x, k)
fn = theano.function([x], y)
# generate a all-unique array
xval = gen_unique_vector(size, dtype)
yval = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
goal = np.sort(xval)[idx]
print(np.sort(yval))
print(goal)
assert yval.dtype == goal.dtype
assert np.allclose(np.sort(yval), goal)
@utt.parameterized.expand(chain(
product(
......@@ -258,38 +305,60 @@ class Test_topk(unittest.TestCase):
('float32', 'int32'),
('int32', 'int64')),
((2049, 1337, 'float32', 'int32'),)))
def test_1d(self, size, k, dtype, out_dtype):
def test_argtopk_1d(self, size, k, dtype, idx_dtype):
if isinstance(k, str):
k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype)
y = argtopk(x, k, out_dtype=out_dtype)
y = argtopk(x, k, idx_dtype=idx_dtype)
fn = theano.function([x], y)
# generate a all-unique array
xval = gen_unique_vector(size, dtype)
yval = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
goal = np.argsort(xval)[idx].astype(out_dtype)
print(yval)
print(goal)
print(np.argsort(xval))
goal = np.argsort(xval)[idx].astype(idx_dtype)
# due to uniqueness, we expect indices same
assert np.all(xval[np.sort(yval)] == xval[np.sort(goal)])
    @utt.parameterized.expand(chain(
        product(
            (16, 61, 257),
            (1, -1, 10, -10, 'n//2', 'n-1', '-n', '1-n'),
            ('float32', 'int32'),
            ('int32', 'int64')),
        ((2049, 1337, 'float32', 'int32'),)))
    def test_combined_1d(self, size, k, dtype, idx_dtype):
        # topk_and_argtopk() on a 1d vector of unique values: indices must
        # select the same elements as an argsort-based reference, and the
        # returned values must match the values at those indices
        if isinstance(k, str):
            # k may be given as an expression in the vector length ``n``
            k = eval(k.replace('n', str(size)))
        x = theano.tensor.vector(name='x', dtype=dtype)
        yv, yi = topk_and_argtopk(x, k, idx_dtype=idx_dtype)
        fn = theano.function([x], [yv, yi])
        # generate an all-unique array so the expected result is unambiguous
        xval = gen_unique_vector(size, dtype)
        yvval, yival = fn(xval)
        idx = slice(-k, None) if k > 0 else slice(-k)
        goali = np.argsort(xval)[idx].astype(idx_dtype)
        goalv = xval[goali]
        # due to uniqueness, we expect indices same
        assert np.all(xval[np.sort(yival)] == xval[np.sort(goali)])
        assert np.allclose(np.sort(yvval), goalv)
@utt.parameterized.expand(chain(
product(
(18, 62, 258),
(1, -1, 'n//2'),
('int32', 'float32')),
((2048, 1337, 'float32'),)))
def test_1d_collision(self, size, k, dtype):
def test_argtopk_1d_collision(self, size, k, dtype):
# with non-unique kth max value
if isinstance(k, str):
k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype)
y = argtopk(x, k, out_dtype='int32')
y = argtopk(x, k, idx_dtype='int32')
fn = theano.function([x], y)
xval = np.repeat(np.random.uniform(-100., 100., size=size // 2).astype(dtype), 2)
xval = xval[np.random.permutation(size)]
......@@ -305,7 +374,7 @@ class Test_topk(unittest.TestCase):
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'),
('float32', 'int32'),
('int32', 'int64')))
def test_nd(self, shp, k_, dtype, out_dtype):
def test_argtopk_nd(self, shp, k_, dtype, idx_dtype):
ndim = len(shp)
for axis in range(-ndim, ndim):
if isinstance(k_, str):
......@@ -318,7 +387,7 @@ class Test_topk(unittest.TestCase):
x = theano.tensor.tensor(
name='x', broadcastable=(False,) * len(shp), dtype=dtype)
y = argtopk(x, k, axis=axis, out_dtype=out_dtype)
y = argtopk(x, k, axis=axis, idx_dtype=idx_dtype)
fn = theano.function([x], y)
size = reduce(int.__mul__, shp)
xval = gen_unique_vector(size, dtype).reshape(shp)
......@@ -327,20 +396,47 @@ class Test_topk(unittest.TestCase):
l = axis % ndim
r = ndim - l
idx = (slice(None),) * l + (idx,) + (slice(None),) * (r - 1)
goal = np.argsort(xval, axis=axis)[idx].astype(out_dtype)
goal = np.argsort(xval, axis=axis)[idx].astype(idx_dtype)
print(dict(k=k, axis=axis, shp=shp))
print('x:')
print(xval)
print('y:')
print(np.sort(yval, axis=axis))
print('goal:')
print(np.sort(goal, axis=axis))
# print(np.argsort(xval))
assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))
class ArgTopKInferShapeTester(utt.InferShapeTester):
class TopKInferShapeTester(utt.InferShapeTester):
    @utt.parameterized.expand(product(
        ((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
        (1, -1, '(1+n)//2', 'n-1', '-n', '1-n')))
    def test_topk_infer_shape(self, shp, k_):
        # verify TopKOp.infer_shape() for topk() over every axis of a
        # variety of nd shapes and k values
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # k may be given as an expression in the axis length ``n``
                k = eval(k_.replace('n', str(shp[axis])))
            else:
                k = k_
            if k == 0:
                # k == 0 is invalid for topk; skip this combination
                continue
            x = theano.tensor.tensor(
                name='x', broadcastable=(False,) * len(shp),
                dtype=theano.config.floatX)
            y = topk(x, k, axis=axis)
            size = reduce(int.__mul__, shp)
            xval = gen_unique_vector(size, theano.config.floatX).reshape(shp)
            self._compile_and_check(
                [x], [y], [xval], TopKOp)
@utt.parameterized.expand(product(
((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n')))
def test_infer_shape(self, shp, k_):
def test_argtopk_infer_shape(self, shp, k_):
ndim = len(shp)
for axis in range(-ndim, ndim):
if isinstance(k_, str):
......@@ -354,8 +450,32 @@ class ArgTopKInferShapeTester(utt.InferShapeTester):
x = theano.tensor.tensor(
name='x', broadcastable=(False,) * len(shp),
dtype=theano.config.floatX)
y = argtopk(x, k, axis=axis, out_dtype='int32')
y = argtopk(x, k, axis=axis, idx_dtype='int32')
size = reduce(int.__mul__, shp)
xval = gen_unique_vector(size, theano.config.floatX).reshape(shp)
self._compile_and_check(
[x], [y], [xval], ArgTopKOp)
[x], [y], [xval], TopKOp)
    @utt.parameterized.expand(product(
        ((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
        (1, -1, '(1+n)//2', 'n-1', '-n', '1-n')))
    def test_combined_infer_shape(self, shp, k_):
        # verify TopKOp.infer_shape() when both outputs (values and indices)
        # are requested, over every axis of a variety of nd shapes and k values
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # k may be given as an expression in the axis length ``n``
                k = eval(k_.replace('n', str(shp[axis])))
            else:
                k = k_
            if k == 0:
                # k == 0 is invalid for topk; skip this combination
                continue
            x = theano.tensor.tensor(
                name='x', broadcastable=(False,) * len(shp),
                dtype=theano.config.floatX)
            yv, yi = topk_and_argtopk(x, k, axis=axis, idx_dtype='int32')
            size = reduce(int.__mul__, shp)
            xval = gen_unique_vector(size, theano.config.floatX).reshape(shp)
            self._compile_and_check(
                [x], [yv, yi], [xval], TopKOp)
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment