提交 7db7bec1 authored 作者: Adam Becker's avatar Adam Becker

flake & cleanups

上级 7b13e2f4
...@@ -33,6 +33,7 @@ KERNEL void k_topk_dense( ...@@ -33,6 +33,7 @@ KERNEL void k_topk_dense(
ga_size gid = GID_0, gidx; ga_size gid = GID_0, gidx;
$set_slice $set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) { //for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i}; // gidx = gid % dims_$${i};
// gid /= dims_$${i}; // gid /= dims_$${i};
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
#define COUNT_TYPE $count_t #define COUNT_TYPE $count_t
#define KERNEL_NAME $kname #define KERNEL_NAME $kname
// works when array size along axis is within [1025, 2^63-1] // if count_t is int, work for array size within [1025, 2^31-1]
// if count_t is long long, work for array size within [2^31, 2^63-1]
template <typename DataType, typename RadixType, typename CountType> template <typename DataType, typename RadixType, typename CountType>
__device__ DataType find_pattern(DataType* smem, __device__ DataType find_pattern(DataType* smem,
DataType* data, DataType* data,
...@@ -58,7 +59,7 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE], ...@@ -58,7 +59,7 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
CountType slice_size, CountType slice_size,
CountType stride, CountType stride,
DataType* data) { DataType* data) {
// Clear out per-thread counts from a previous round // Clear out per-thread counts from a previous round
#pragma unroll #pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0; counts[i] = 0;
...@@ -73,12 +74,12 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE], ...@@ -73,12 +74,12 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
for (CountType i = LID_0; i < slice_size; i += LDIM_0) { for (CountType i = LID_0; i < slice_size; i += LDIM_0) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride)); RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool hasVal = ((val & known_bits_mask) == known_bits); bool has_val = ((val & known_bits_mask) == known_bits);
RadixType digit_in_radix = Bitfield<RadixType>::get(val, radix_digit_pos, RADIX_BITS); RadixType digit_in_radix = Bitfield<RadixType>::get(val, radix_digit_pos, RADIX_BITS);
#pragma unroll #pragma unroll
for (int j = 0; j < RADIX_SIZE; ++j) { for (int j = 0; j < RADIX_SIZE; ++j) {
bool vote = hasVal && (digit_in_radix == j); bool vote = has_val && (digit_in_radix == j);
counts[j] += __popc(__ballot(vote)); counts[j] += __popc(__ballot(vote));
} }
} }
...@@ -222,6 +223,7 @@ KERNEL void KERNEL_NAME( ...@@ -222,6 +223,7 @@ KERNEL void KERNEL_NAME(
// dims_1+ <- batched dimensions // dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx; ga_uint gid = GID_0, gidx;
$set_slice $set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) { //for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i}; // gidx = gid % dims_$${i};
// gid /= dims_$${i}; // gid /= dims_$${i};
...@@ -245,7 +247,7 @@ KERNEL void KERNEL_NAME( ...@@ -245,7 +247,7 @@ KERNEL void KERNEL_NAME(
// `has_topk`. This will return the resulting index into which we // `has_topk`. This will return the resulting index into which we
// need to write the result, if a thread has a result. // need to write the result, if a thread has a result.
// All threads need to participate in the loop and the prefix sum, // All threads need to participate in the loop and the cumsum
// but not necessarily in the load; hence loop bounds being rounded // but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim. // up to a multiple of the block dim.
COUNT_TYPE iter_bound = size + LDIM_0-1; COUNT_TYPE iter_bound = size + LDIM_0-1;
......
...@@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division ...@@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division
import os import os
from string import Template from string import Template
import numpy as np
from theano import Apply from theano import Apply
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
from theano.tensor.sort import TopKOp from theano.tensor.sort import TopKOp
...@@ -116,7 +115,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -116,7 +115,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def build_kernel(fname, kname, subs): def build_kernel(fname, kname, subs):
with open(os.path.join( with open(os.path.join(
os.path.dirname(__file__), 'c_code', fname)) as f: os.path.dirname(__file__), 'c_code', fname)
) as f:
kernel_src = f.read() kernel_src = f.read()
ker = Kernel( ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs), code=Template(common_src + kernel_src).substitute(**subs),
......
...@@ -359,10 +359,11 @@ class TopKOp(theano.Op): ...@@ -359,10 +359,11 @@ class TopKOp(theano.Op):
'"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype) '"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype)
if not (return_indices or return_values): if not (return_indices or return_values):
raise ValueError("Neither return_values nor return_indices is True, this isn't allowed") raise ValueError(
"Neither return_values nor return_indices is True, this isn't allowed")
self.axis = axis self.axis = axis
self.sorted=sorted self.sorted = sorted
self.return_values = return_values self.return_values = return_values
self.return_indices = return_indices self.return_indices = return_indices
self.idx_dtype = idx_dtype self.idx_dtype = idx_dtype
...@@ -454,7 +455,7 @@ def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'): ...@@ -454,7 +455,7 @@ def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
If ``None``, works on flattened array. If ``None``, works on flattened array.
sorted: bool sorted: bool
NOTE: NOT IMPLEMENTED YET NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True`` Defaults to ``True``
If True, the result array would be sorted in descending order. If True, the result array would be sorted in descending order.
...@@ -497,6 +498,7 @@ def argtopk(x, kth, axis=-1, sorted=True, idx_dtype='int64'): ...@@ -497,6 +498,7 @@ def argtopk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
Must not be 0. If negative, gives k-smallest elements instead. Must not be 0. If negative, gives k-smallest elements instead.
sorted: bool sorted: bool
NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True`` Defaults to ``True``
If True, the result array of corresponding indices would be sorted in descending order. If True, the result array of corresponding indices would be sorted in descending order.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论