提交 7db7bec1 作者: Adam Becker

flake & cleanups

上级 7b13e2f4
......@@ -33,6 +33,7 @@ KERNEL void k_topk_dense(
ga_size gid = GID_0, gidx;
$set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
......
......@@ -5,7 +5,8 @@
#define COUNT_TYPE $count_t
#define KERNEL_NAME $kname
// works when array size along axis is within [1025, 2^63-1]
// if count_t is int, works for array sizes within [1025, 2^31-1]
// if count_t is long long, works for array sizes within [2^31, 2^63-1]
template <typename DataType, typename RadixType, typename CountType>
__device__ DataType find_pattern(DataType* smem,
DataType* data,
......@@ -58,7 +59,7 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
CountType slice_size,
CountType stride,
DataType* data) {
// Clear out per-thread counts from a previous round
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0;
......@@ -73,12 +74,12 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
for (CountType i = LID_0; i < slice_size; i += LDIM_0) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool hasVal = ((val & known_bits_mask) == known_bits);
bool has_val = ((val & known_bits_mask) == known_bits);
RadixType digit_in_radix = Bitfield<RadixType>::get(val, radix_digit_pos, RADIX_BITS);
#pragma unroll
for (int j = 0; j < RADIX_SIZE; ++j) {
bool vote = hasVal && (digit_in_radix == j);
bool vote = has_val && (digit_in_radix == j);
counts[j] += __popc(__ballot(vote));
}
}
......@@ -222,6 +223,7 @@ KERNEL void KERNEL_NAME(
// dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx;
$set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
......@@ -245,7 +247,7 @@ KERNEL void KERNEL_NAME(
// `has_topk`. This will return the resulting index into which we
// need to write the result, if a thread has a result.
// All threads need to participate in the loop and the prefix sum,
// All threads need to participate in the loop and the cumsum
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
COUNT_TYPE iter_bound = size + LDIM_0-1;
......
......@@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division
import os
from string import Template
import numpy as np
from theano import Apply
from theano.tensor import as_tensor_variable
from theano.tensor.sort import TopKOp
......@@ -116,7 +115,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def build_kernel(fname, kname, subs):
with open(os.path.join(
os.path.dirname(__file__), 'c_code', fname)) as f:
os.path.dirname(__file__), 'c_code', fname)
) as f:
kernel_src = f.read()
ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs),
......
......@@ -359,10 +359,11 @@ class TopKOp(theano.Op):
'"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype)
if not (return_indices or return_values):
raise ValueError("Neither return_values nor return_indices is True, this isn't allowed")
raise ValueError(
"Neither return_values nor return_indices is True, this isn't allowed")
self.axis = axis
self.sorted=sorted
self.sorted = sorted
self.return_values = return_values
self.return_indices = return_indices
self.idx_dtype = idx_dtype
......@@ -454,7 +455,7 @@ def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
If ``None``, works on flattened array.
sorted: bool
NOTE: NOT IMPLEMENTED YET
NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True``
If True, the result array would be sorted in descending order.
......@@ -497,6 +498,7 @@ def argtopk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
Must not be 0. If negative, gives k-smallest elements instead.
sorted: bool
NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True``
If True, the result array of corresponding indices would be sorted in descending order.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论