Commit 5cc9c326, authored by abergeron, committed by GitHub

Merge pull request #5959 from khaotik/topk

TopKOp implementation
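The new user-facing entry points are exported from theano.tensor.sort (see the import changes below). A minimal usage sketch, assuming topk/argtopk default to the last axis and that the second positional argument is the number of elements to keep, as suggested by the TopKOp(x, k) call in the optimization further down:

import numpy as np
import theano
import theano.tensor as T
from theano.tensor.sort import topk, argtopk

x = T.matrix('x')
v = topk(x, 3)      # values of the 3 largest entries along the last axis (assumed default)
i = argtopk(x, 3)   # their indices
f = theano.function([x], [v, i])
print(f(np.random.rand(4, 10).astype(theano.config.floatX)))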
@@ -1367,21 +1367,27 @@ class LocalOptGroup(LocalOptimizer):
                     self.process_count[opt] += 1
                 if not new_repl:
                     continue
-                else:
-                    if self.profile:
-                        self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
-                        self.applied_true[opt] += 1
-                    break  # break from the for loop over optimization.
+                if isinstance(new_repl, (tuple, list)):
+                    new_vars = new_repl
+                else:  # It must be a dict
+                    new_vars = list(new_repl.values())
+                if self.profile:
+                    self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
+                    self.applied_true[opt] += 1
+                break  # break from the for loop over optimization.
             if not new_repl:  # No optimization applied in the last iteration
                 return repl
-            # only 1 iteration or we are at the start of the graph.
-            if not self.apply_all_opts or not new_repl[0].owner:
+            # only 1 iteration
+            if not self.apply_all_opts:
+                return new_repl
+            if not new_vars[0].owner:
+                # We are at the start of the graph.
                 return new_repl
             if len(new_repl) > 1:
                 s = set([v.owner for v in new_repl])
                 assert len(s) == 1
             repl = new_repl
-            node = repl[0].owner
+            node = new_vars[0].owner

     @staticmethod
     def print_profile(stream, prof, level=0):
@@ -28,7 +28,7 @@ from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
                    GpuArraySharedVariable, gpuarray_shared_constructor,
                    reg_context, get_context, ContextNotDefined)
 from .basic_ops import as_gpuarray_variable
-from . import fft, dnn, opt, extra_ops, multinomial, reduction, rng_mrg, ctc
+from . import fft, dnn, opt, extra_ops, multinomial, reduction, sort, rng_mrg, ctc


 def transfer(x, target):
[diff collapsed]
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
    $dims
    // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
    $dstv
    // INPUT_TYPE *dstv
    $dstv_strides
    // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
    $dsti
    // INDEX_TYPE *dsti
    $dsti_strides
    // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
    ssize_t k,
    INPUT_TYPE* src,
    $src_strides
    // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
    size_t size) {
  __shared__ int smem[32 * RADIX_SIZE];
  __shared__ int k2;
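  // smem: per-warp digit histograms (up to 32 warps x RADIX_SIZE bins each),
  // also reused as scratch by the cross-warp reductions and cumsums below
  // k2: block-shared count of top-k slots still to be filled, refined on
  // each radix pass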
  const unsigned int idx = threadIdx.x;
  bool is_topk = (idx < size);
  bool is_topkth = is_topk;
  size_t out_idx;
  const unsigned char warp_id = idx / GA_WARP_SIZE;

  // 0. get the slice for thread block to work on
  size_t gid = blockIdx.x, gidx;
  $set_slice
  // $$set_slice expands into:
  // for (int i=1; i<NDIM; i++) {
  //   gidx = gid % dims_$${i};
  //   gid /= dims_$${i};
  //   dsti = ptr_add(dsti, gidx*dsti_strides_$${i});
  //   dstv = ptr_add(dstv, gidx*dstv_strides_$${i});
  //   src = ptr_add(src, gidx*src_strides_$${i});
  // }

  // get input and its radix friendly form
  const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
  radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);
  // resolve negative k
  if (k<0) { x = ~x; k = -k; }
  if (idx==0)
    k2 = k;

  // 1. filter is_topk and is_topkth using radix select
  #pragma unroll
  for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
    const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
    /*int digit = (x>>i) & (RADIX_SIZE-1);*/

    // count within warp
    #pragma unroll
    for (int bin=0; bin<RADIX_SIZE; ++bin) {
      bool vote = (bin == digit) && is_topkth;
      unsigned int votes = __ballot(vote);
      if (lane_id()==0)
        smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
    }
    local_barrier();

    // sum counts across all warps
    if (idx < RADIX_SIZE) {
      int sum = smem[idx];
      #pragma unroll
      for (int w=RADIX_SIZE; w<blockDim.x*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
        sum += smem[idx + w];
      smem[idx] = sum;
    }
    local_barrier();

    // find the bucket and update k2
    // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
    if (idx == 0) {
      int sum = k2;
      #pragma unroll
      for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
        sum -= smem[bin];
        smem[bin] = sum;
        k2 = (sum > 0) ? sum : k2;
      }
      smem[RADIX_SIZE] = 1;
    }
    local_barrier();

    if (is_topkth) {
      is_topk &= (smem[digit+1] > 0);
      is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
    }
    local_barrier();
  }

  // set k2 as number of exceeding values
  if (idx==0) {
    #pragma unroll
    for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
      if (smem[bin] <= 0)
        break;
      k2 = smem[bin];
    }
  }
  local_barrier();

  // 2. find the index of output array, if exists
  if (k2 != 0) {
    // top_kth value may not be unique, so we need to
    // perform binary cumsum on is_topkth to drop exceeding top-kth values
    out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
    if ((out_idx >= k2) && is_topkth)
      is_topk = false;
    local_barrier();
  }

  // perform binary cumsum on is_topk to determine the indices to put result
  out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);
  if (is_topk) {
#if WRITE_VALUE == 1
    ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
    ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
  }
}
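For reference, the radix-select loop above (step 1 together with the k2 bookkeeping) can be modelled in NumPy. This is a sketch of the idea rather than a port of the kernel: it assumes the inputs have already been converted to an unsigned, order-preserving encoding (what RadixConfig<INPUT_TYPE>::convert does on the device) and that 0 < k <= size; the final tie-breaking that the kernel does with binary_cumsum_exclusive is only noted in the comments.

import numpy as np

def radix_topk_flags(x_bits, k, radix_bits=4, total_bits=32):
    # x_bits: 1-D array of unsigned, order-preserving encodings of the input.
    # Returns (is_topk, is_topkth): is_topk marks every element >= the k-th
    # largest (ties included); is_topkth marks the surviving candidates for
    # the k-th value itself.
    is_topk = np.ones(x_bits.shape, dtype=bool)
    is_topkth = is_topk.copy()
    for shift in range(total_bits - radix_bits, -1, -radix_bits):
        digit = (x_bits >> shift) & ((1 << radix_bits) - 1)
        counts = np.bincount(digit[is_topkth], minlength=1 << radix_bits)
        count_ge = np.cumsum(counts[::-1])[::-1]   # candidates with digit >= bin
        # pivot bucket: highest digit at which the count from the top reaches k
        pivot = np.max(np.nonzero(count_ge >= k)[0])
        # candidates above the pivot are confirmed, below the pivot are dropped,
        # at the pivot they remain candidates for the k-th value
        is_topk &= ~(is_topkth & (digit < pivot))
        k -= int(counts[pivot + 1:].sum())         # slots taken by higher buckets
        is_topkth &= (digit == pivot)
    # ties on the k-th value can leave more than k survivors in is_topk;
    # the kernel trims them with binary_cumsum_exclusive before writing out
    return is_topk, is_topkth

# quick check on unsigned integers (already order-preserving)
a = np.random.randint(0, 1 << 20, size=64).astype(np.uint32)
mask, _ = radix_topk_flags(a, 5)
print(np.sort(a[mask])[::-1])   # the 5 largest values (plus any ties on the 5th)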
[diff collapsed]
[diff collapsed]
@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
 from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
     jacobian, hessian, consider_constant
-from theano.tensor.sort import sort, argsort
+from theano.tensor.sort import sort, argsort, topk, argtopk, topk_and_argtopk
 from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
                                      repeat, bartlett, fill_diagonal, fill_diagonal_offset,
                                      cumsum, cumprod, unravel_index, ravel_multi_index)
@@ -35,6 +35,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
                                      advanced_subtensor,
                                      advanced_subtensor1,
                                      advanced_inc_subtensor1)
+from theano.tensor.sort import TopKOp
 from theano import scalar
 from theano.scalar import basic
 from theano.tensor import basic as T
@@ -7548,3 +7549,35 @@ def local_merge_alloc(node):
                           dim_outer, T.eq(dim_outer, dim_inner))
         i += 1
     return [T.alloc(inputs_inner[0], *dims_outer)]
+
+
+@register_useless('fast_compile')
+@gof.local_optimizer([TopKOp])
+def local_useless_topk(node):
+    """
+    TopKOp generates two outputs by default
+    This opt removes the useless ones
+    """
+    op = node.op
+    if not isinstance(op, TopKOp):
+        return
+    if not (op.return_values and op.return_indices):
+        return False
+    x, k = node.inputs
+    ret_val = bool(node.outputs[0].clients)
+    ret_idx = bool(node.outputs[1].clients)
+    if not (ret_val ^ ret_idx):
+        # both true -> nothing to remove
+        # both false -> let pruner handle
+        return False
+    old_output = node.outputs[ret_idx]
+    new_output = TopKOp(
+        axis=op.axis,
+        idx_dtype=op.idx_dtype,
+        return_values=ret_val,
+        return_indices=ret_idx)(x, k)
+    return {old_output: new_output}
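For context, a hypothetical graph on which this rewrite fires (assuming, as in the usage sketch near the top, that theano.tensor.sort.topk_and_argtopk builds a single TopKOp node and returns the pair (values, indices)):

import theano
import theano.tensor as T
from theano.tensor.sort import topk_and_argtopk

x = T.vector('x')
values, indices = topk_and_argtopk(x, 5)
# Only `indices` reaches the compiled function, so node.outputs[0] has no
# clients and local_useless_topk can rebuild the node with
# return_values=False, return_indices=True.
f = theano.function([x], indices)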
[diff collapsed]
@@ -83,7 +83,7 @@ def seed_rng(pseed=None):
 def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
     """
-    Wrapper for tensor/basic.py:verify_grad
+    Wrapper for gradient.py:verify_grad
     Takes care of seeding the random number generator if None is given
     """

     if rng is None: