提交 5cc9c326 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5959 from khaotik/topk

TopKOp implementation
......@@ -1367,21 +1367,27 @@ class LocalOptGroup(LocalOptimizer):
self.process_count[opt] += 1
if not new_repl:
continue
else:
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
self.applied_true[opt] += 1
break # break from the for loop over optimization.
if isinstance(new_repl, (tuple, list)):
new_vars = new_repl
else: # It must be a dict
new_vars = list(new_repl.values())
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
self.applied_true[opt] += 1
break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration
return repl
# only 1 iteration or we are at the start of the graph.
if not self.apply_all_opts or not new_repl[0].owner:
# only 1 iteration
if not self.apply_all_opts:
return new_repl
if not new_vars[0].owner:
# We are at the start of the graph.
return new_repl
if len(new_repl) > 1:
s = set([v.owner for v in new_repl])
assert len(s) == 1
repl = new_repl
node = repl[0].owner
node = new_vars[0].owner
@staticmethod
def print_profile(stream, prof, level=0):
......
......@@ -28,7 +28,7 @@ from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
GpuArraySharedVariable, gpuarray_shared_constructor,
reg_context, get_context, ContextNotDefined)
from .basic_ops import as_gpuarray_variable
from . import fft, dnn, opt, extra_ops, multinomial, reduction, rng_mrg, ctc
from . import fft, dnn, opt, extra_ops, multinomial, reduction, sort, rng_mrg, ctc
def transfer(x, target):
......
差异被折叠。
// Radix-select configuration: candidates are filtered RADIX_BITS bits at
// a time, scanning from the most significant digit downwards.
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// Dense top-k selection: one thread block per slice of the input,
// one thread per element along the selected axis.
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
        $dims
        // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
        ssize_t k,
        INPUT_TYPE* src,
        $src_strides
        // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
        size_t size) {
    // Per-warp digit histograms (up to 32 warps for a 1024-thread block);
    // also reused as scratch by the binary_cumsum_exclusive helper below.
    __shared__ int smem[32 * RADIX_SIZE];
    // Number of top-k slots still unresolved; shared by the whole block.
    __shared__ int k2;
    const unsigned int idx = threadIdx.x;
    // is_topk   : this thread's value is still a top-k candidate
    // is_topkth : this thread's value may still be the k-th (threshold) value
    bool is_topk = (idx < size);
    bool is_topkth = is_topk;
    size_t out_idx;
    const unsigned char warp_id = idx / GA_WARP_SIZE;
    // 0. get the slice for thread block to work on
    size_t gid = blockIdx.x, gidx;
    $set_slice
    // $$set_slice expands into:
    //for(int i=1; i<NDIM; i++) {
    //    gidx = gid % dims_$${i};
    //    gid /= dims_$${i};
    //    dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
    //    dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
    //    src = ptr_add(src, gidx*src_strides_$${i});
    //}
    // get input and its radix friendly form
    const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
    radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);
    // resolve negative k: a negative k selects from the other end of the
    // ordering, so flip all radix bits to invert the comparison order.
    // NOTE(review): assumes RadixConfig::convert is order-monotone so that
    // ~x reverses the ordering — confirm against RadixConfig definition.
    if (k<0) { x = ~x; k = -k; }
    if (idx==0)
        k2 = k;
    // 1. filter is_topk and is_topkth using radix select, one RADIX_BITS
    //    digit per iteration, most significant digit first
    #pragma unroll
    for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
        /*int digit = (x>>i) & (RADIX_SIZE-1);*/
        // count within warp: one ballot per bucket; lane 0 stores the
        // warp's population count for that bucket
        #pragma unroll
        for (int bin=0; bin<RADIX_SIZE; ++bin) {
            bool vote = (bin == digit) && is_topkth;
            unsigned int votes = __ballot(vote);
            if (lane_id()==0)
                smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
        }
        local_barrier();
        // sum counts across all warps
        if (idx < RADIX_SIZE) {
            int sum = smem[idx];
            #pragma unroll
            for(int w=RADIX_SIZE; w<blockDim.x*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
                sum += smem[idx + w];
            smem[idx] = sum;
        }
        local_barrier();
        // find the bucket and update k2
        // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
        // After this pass, smem[bin] == k2 minus the number of surviving
        // candidates whose current digit is >= bin; the k-th value lies in
        // the bucket where this quantity crosses zero.
        if (idx == 0) {
            int sum = k2;
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
                sum -= smem[bin];
                smem[bin] = sum;
                k2 = (sum > 0) ? sum : k2;
            }
            // sentinel so smem[digit+1] is well-defined for the top bucket
            smem[RADIX_SIZE] = 1;
        }
        local_barrier();
        if (is_topkth) {
            // smem[digit+1] > 0 : strictly higher buckets hold fewer than k2
            //                     values, so this value can still be top-k
            // smem[digit]  <= 0 : this bucket itself reaches k2, so the
            //                     threshold (k-th) value is inside it
            is_topk &= (smem[digit+1] > 0);
            is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
        }
        local_barrier();
    }
    // set k2 as number of exceeding values
    // (slots left for threshold-valued entries once larger ones are placed)
    if (idx==0) {
        #pragma unroll
        for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
            if (smem[bin] <= 0)
                break;
            k2 = smem[bin];
        }
    }
    local_barrier();
    // 2. find the index of output array, if exists
    if (k2 != 0) {
        // top_kth value may not be unique, so we need to
        // perform binary cumsum on is_topkth to drop exceeding top-kth values
        out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
        if ((out_idx >= k2) && is_topkth)
            is_topk = false;
        local_barrier();
    }
    // perform binary cumsum on is_topk to determine the indices to put result
    out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);
    if (is_topk) {
#if WRITE_VALUE == 1
        ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
        ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
    }
}
差异被折叠。
差异被折叠。
......@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian, consider_constant
from theano.tensor.sort import sort, argsort
from theano.tensor.sort import sort, argsort, topk, argtopk, topk_and_argtopk
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, fill_diagonal_offset,
cumsum, cumprod, unravel_index, ravel_multi_index)
......
......@@ -35,6 +35,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
advanced_subtensor,
advanced_subtensor1,
advanced_inc_subtensor1)
from theano.tensor.sort import TopKOp
from theano import scalar
from theano.scalar import basic
from theano.tensor import basic as T
......@@ -7548,3 +7549,35 @@ def local_merge_alloc(node):
dim_outer, T.eq(dim_outer, dim_inner))
i += 1
return [T.alloc(inputs_inner[0], *dims_outer)]
@register_useless('fast_compile')
@gof.local_optimizer([TopKOp])
def local_useless_topk(node):
    """Strip the unused output from a two-output ``TopKOp``.

    A ``TopKOp`` constructed with both ``return_values`` and
    ``return_indices`` produces two outputs.  When exactly one of them is
    actually consumed by the graph, rebuild the op so it computes only
    that output.

    Parameters
    ----------
    node : Apply
        The candidate node, expected to hold a ``TopKOp``.

    Returns
    -------
    dict or bool
        ``{old_output: new_output}`` when the replacement applies,
        ``False`` otherwise.
    """
    op = node.op
    # Consistency fix: all failure paths now return False (the first guard
    # previously fell through with a bare `return`, i.e. None).
    if not isinstance(op, TopKOp):
        return False
    if not (op.return_values and op.return_indices):
        # Already a single-output op; nothing to strip.
        return False
    x, k = node.inputs
    ret_val = bool(node.outputs[0].clients)
    ret_idx = bool(node.outputs[1].clients)
    if not (ret_val ^ ret_idx):
        # both True  -> both outputs are used, nothing to remove
        # both False -> dead node, let the pruner handle it
        return False
    # Exactly one output is used; ret_idx doubles as its index (False->0
    # selects the values output, True->1 selects the indices output).
    old_output = node.outputs[ret_idx]
    new_output = TopKOp(
        axis=op.axis,
        idx_dtype=op.idx_dtype,
        return_values=ret_val,
        return_indices=ret_idx)(x, k)
    return {old_output: new_output}
差异被折叠。
......@@ -83,7 +83,7 @@ def seed_rng(pseed=None):
def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
"""
Wrapper for tensor/basic.py:verify_grad
Wrapper for gradient.py:verify_grad
Takes care of seeding the random number generator if None is given
"""
if rng is None:
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论