Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7b13e2f4
提交
7b13e2f4
authored
6月 25, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multiple improvements to gpu topk
- added xlarge kernel to handle array size >= 2^31 - ported original pytorch kernel - various small fixes
上级
330dd345
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
318 行增加
和
156 行删除
+318
-156
k_topk_common.cuh
theano/gpuarray/c_code/k_topk_common.cuh
+7
-1
k_topk_dense.cu
theano/gpuarray/c_code/k_topk_dense.cu
+2
-3
k_topk_dense_large.cu
theano/gpuarray/c_code/k_topk_dense_large.cu
+256
-106
sort.py
theano/gpuarray/sort.py
+49
-45
sort.py
theano/tensor/sort.py
+4
-1
没有找到文件。
theano/gpuarray/c_code/k_topk_common.cuh
浏览文件 @
7b13e2f4
...
@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
...
@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)"
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
#endif
void __device__ atomicAdd(long long *dst, long long &src) {
atomicAdd(
reinterpret_cast<unsigned long long*>(dst),
reinterpret_cast<unsigned long long&>(src));
}
template <typename T>
template <typename T>
static inline __device__ T binary_cumsum(
static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
int idx, int warp_id, T* smem, bool value) {
...
@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
...
@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
// read array element using raw(byte) offset
// read array element using raw(byte) offset
template <typename T>
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
static __device__ inline T ptr_read
_cached
(T *ptr, ga_ssize offset) {
return __ldg(((T*)((char*)ptr + offset)));
return __ldg(((T*)((char*)ptr + offset)));
}
}
theano/gpuarray/c_code/k_topk_dense.cu
浏览文件 @
7b13e2f4
...
@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
...
@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
ga_size gid = GID_0, gidx;
$set_slice
$set_slice
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
...
@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
...
@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
}
}
local_barrier();
local_barrier();
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
if (idx == 0) {
ga_int sum = k2;
ga_int sum = k2;
...
@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
...
@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
#endif
#endif
}
}
}
}
theano/gpuarray/c_code/k_topk_dense_large.cu
浏览文件 @
7b13e2f4
#define RADIX_BITS 2
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is in [1025, 2^31-1]
#define COUNT_TYPE $count_t
KERNEL void k_topk_dense_large(
#define KERNEL_NAME $kname
// works when array size along axis is within [1025, 2^63-1]
template <typename DataType, typename RadixType, typename CountType>
__device__ DataType find_pattern(DataType* smem,
DataType* data,
CountType slice_size,
CountType stride,
RadixType known_bits,
RadixType known_bits_mask) {
if (LID_0 < 32)
smem[LID_0] = 0;
local_barrier();
// All threads participate in the loop, in order to sync on the flag
for (CountType i = LID_0; i < (slice_size + (CountType)LDIM_0-1); i += LDIM_0) {
bool in_range = (i < slice_size);
DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;
if (in_range && ((RadixConfig<DataType>::convert(v) & known_bits_mask) == known_bits)) {
// There should not be conflicts if we are using find_pattern,
// since the result is unique
smem[0] = 1;
smem[1] = v; // can't use val as the flag, since it could be 0
}
local_barrier();
DataType found = smem[0];
DataType val = smem[1];
local_barrier();
// Check to see if a thread found the value
if (found != 0)
return val;
}
return 0;
}
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radix_digit_pos`, but only
// those that pass the filter `((v & known_bits_mask) == known_bits)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RADIX_SIZE` elements.
template <typename DataType, typename RadixType, typename CountType>
__device__ void count_radix_masked(CountType counts[RADIX_SIZE],
CountType* smem,
RadixType known_bits,
RadixType known_bits_mask,
int radix_digit_pos,
CountType slice_size,
CountType stride,
DataType* data) {
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0;
if (LID_0 < RADIX_SIZE)
smem[LID_0] = 0;
local_barrier();
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
for (CountType i = LID_0; i < slice_size; i += LDIM_0) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool hasVal = ((val & known_bits_mask) == known_bits);
RadixType digit_in_radix = Bitfield<RadixType>::get(val, radix_digit_pos, RADIX_BITS);
#pragma unroll
for (int j = 0; j < RADIX_SIZE; ++j) {
bool vote = hasVal && (digit_in_radix == j);
counts[j] += __popc(__ballot(vote));
}
}
// Now, for each warp, sum values
if (lane_id() == 0) {
for (int i=0; i<RADIX_SIZE; ++i)
atomicAdd(&smem[i], counts[i]);
}
/*
// not sure why, but this just give wrong results
if (lane_id() < RADIX_SIZE)
atomicAdd(&smem[lane_id()], counts[lane_id()]);
*/
local_barrier();
// For each thread, read in the total counts
#pragma unroll
for (unsigned int i = 0; i < RADIX_SIZE; ++i)
counts[i] = smem[i];
local_barrier();
}
template <typename DataType, typename RadixType, typename CountType>
__device__ void radix_select(DataType* data,
CountType k,
bool order,
CountType slice_size,
CountType stride,
CountType* smem,
DataType* top_kth) {
// Per-thread buckets into which we accumulate digit counts in our
// radix
register CountType counts[RADIX_SIZE];
// We only consider elements x such that (x & known_bits_mask) == known_bits
// Initially, we consider all elements of the array, so the above
// statement is true regardless of input.
RadixType known_bits = 0, known_bits_mask = 0;
// We are looking for the top k_to_find-th element when iterating over
// digits; this count gets reduced by elimination when counting
// successive digits
CountType k_to_find = abs(k);
// We start at the most significant digit in our radix, scanning
// through to the least significant digit
#pragma unroll
for (int digit_pos = bitsof(DataType) - RADIX_BITS;
digit_pos >= 0; digit_pos -= RADIX_BITS) {
// Count radix distribution for the current position and reduce
// across all threads
count_radix_masked<DataType, RadixType, CountType>(
counts, smem,
known_bits, known_bits_mask, digit_pos,
slice_size, stride, data);
// All threads participate in the comparisons below to know the
// final result
#define CHECK_RADIX(i) \\
int count = counts[i]; \\
/* All threads have the same value in counts here, so all */ \\
/* threads will return from the function. */ \\
if (count == 1 && k_to_find == 1) { \\
/* There is a unique answer. */ \\
known_bits = Bitfield<RadixType>::set( \\
known_bits, i, digit_pos, RADIX_BITS); \\
known_bits_mask = Bitfield<RadixType>::set( \\
known_bits_mask, RADIX_SIZE-1, digit_pos, RADIX_BITS); \\
/* The answer is now the unique element v such that: */ \\
/* (v & known_bits_mask) == known_bits */ \\
/* However, we do not yet know what the actual element is. We */ \\
/* need to perform a search through the data to find the */ \\
/* element that matches this pattern. */ \\
*top_kth = find_pattern<DataType, RadixType, CountType>( \\
(DataType*) smem, data, slice_size, \\
stride, known_bits, known_bits_mask); \\
return; \\
} \\
if (count >= k_to_find) { \\
known_bits = Bitfield<RadixType>::set(known_bits, i, digit_pos, RADIX_BITS); \\
known_bits_mask = Bitfield<RadixType>::set( \\
known_bits_mask, RADIX_SIZE-1, digit_pos, RADIX_BITS); \\
/* The top-Kth element v must now be one such that: */ \\
/* (v & known_bits_mask == known_bits) */ \\
/* but we haven't narrowed it down; we must check the next */ \\
/* least-significant digit */ \\
break; \\
} \\
k_to_find -= count
if (order) {
#pragma unroll
for (int i=RADIX_SIZE - 1; i >= 0; --i) {
CHECK_RADIX(i);
}
} else {
#pragma unroll
for (int i=0; i < RADIX_SIZE; ++i) {
CHECK_RADIX(i);
}
}
#undef CHECK_RADIX
} // end digit_pos for
// There is no unique result, but there is a non-unique result
// matching `known_bits` exactly
*top_kth = RadixConfig<DataType>::deconvert(known_bits);
}
KERNEL void KERNEL_NAME(
$dims
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
$dstv
...
@@ -19,139 +208,100 @@ KERNEL void k_topk_dense_large(
...
@@ -19,139 +208,100 @@ KERNEL void k_topk_dense_large(
INPUT_TYPE* src,
INPUT_TYPE* src,
$src_strides
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) {
ga_size size) {
LOCAL_MEM ga_int smem[32];
LOCAL_MEM COUNT_TYPE smem[32];
LOCAL_MEM radix_t known_bits;
INPUT_TYPE topkth_value;
LOCAL_MEM ga_uint k2;
int counts[RADIX_SIZE];
const bool order = (k>0);
unsigned out_idx;
k = (order ? k : -k);
INPUT_TYPE xval;
const ga_int idx = LID_0;
radix_t x;
bool in_range, is_topk;
const ga_uint idx = LID_0;
const ga_uint inp_idx = idx * inp_per_thread;
const ga_int warp_id = idx / GA_WARP_SIZE;
const ga_int warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
// size <- the axis to work on
// dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx;
ga_uint gid = GID_0, gidx;
$set_slice
$set_slice
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i}
)
;
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i}
)
;
// src = ptr_add(src, gidx*src_strides_$${i});
// src = ptr_add(src, gidx*src_strides_$${i});
//}
//}
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
if (idx==0) {
known_bits = 0;
k2 = (k>=0) ? k : -k;
}
const radix_t inv_bits = (k>=0) ? 0 : ~0;
if (k<0) { k = -k; }
local_barrier();
// 1. find bits of top-k-th value using radix select
radix_select<INPUT_TYPE, radix_t, COUNT_TYPE>(
#pragma unroll
src, k, order, size, src_strides_0,
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
smem, &topkth_value);
#pragma unroll
for (int j=0; j<RADIX_SIZE; ++j)
counts[j] = 0;
if (warp_id == 0)
smem[idx] = 0;
local_barrier();
// count within warp
// Every value that is strictly less/greater than `pattern`
for (int j=0; j<inp_per_thread; ++j) {
// (depending on sort dir) in sorted int format is in the top-K.
in_range = (inp_idx+j) < size;
// The top-K value itself might not be unique.
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
//
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
// Since there are a variable number of elements that we see that
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
// are within the top-k, we don't know at what index to write out
// the resulting values.
#pragma unroll
// In order to get this, we perform an exclusive cumsum of
for (int bin=0; bin<RADIX_SIZE; ++bin) {
// `has_topk`. This will return the resulting index into which we
bool incr_bin = (
// need to write the result, if a thread has a result.
(bin == digit) &&
((x >> (i+RADIX_BITS)) == known_bits) && in_range);
counts[bin] += __popc(__ballot(incr_bin));
}
}
local_barrier();
// sum counts across all warps
// All threads need to participate in the loop and the prefix sum,
if (lane_id() < RADIX_SIZE) {
// but not necessarily in the load; hence loop bounds being rounded
atomicAdd(&smem[lane_id()], counts[lane_id()]);
// up to a multiple of the block dim.
}
COUNT_TYPE iter_bound = size + LDIM_0-1;
local_barrier()
;
INDEX_TYPE write_base = 0
;
// update known bits
for (int i = idx; i < iter_bound; i += LDIM_0) {
if (idx==0) {
bool in_range = (i < size);
#pragma unroll
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
bool has_topk;
if (smem[bin] >= k2) {
if (order) {
known_bits = (known_bits << RADIX_BITS) | bin;
has_topk = in_range && (v > topkth_value);
break;
} else {
} else
has_topk = in_range && (v < topkth_value);
k2 -= smem[bin];
}
}
}
local_barrier();
}
// now we use k2 for base index to write output
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
if (idx == 0)
int carry = smem[LDIM_0 / 32 - 1];
k2 = 0;
local_barrier();
// 2. write values smaller than top-kth
if (has_topk) {
for (int i=0; i<inp_per_thread; ++i) {
COUNT_TYPE write_idx = write_base + index;
in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
if (is_topk) {
#if WRITE_VALUE == 1
#if WRITE_VALUE == 1
ptr_at(dstv,
(out_idx+k2-1) * dstv_strides_0) = xval
;
ptr_at(dstv,
write_idx * dstv_strides_0) = v
;
#endif
#endif
#if WRITE_INDEX == 1
#if WRITE_INDEX == 1
ptr_at(dsti,
(out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i)
;
ptr_at(dsti,
write_idx * dsti_strides_0) = (INDEX_TYPE)i
;
#endif
#endif
}
}
local_barrier();
if (idx == blockDim.x - 1)
write_base += carry;
k2 += out_idx;
local_barrier();
}
}
// 3. write values equal to top-kth
for (int i=0; i<inp_per_thread; ++i) {
COUNT_TYPE topk_remaining = (k - write_base);
in_range = (inp_idx+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
for (COUNT_TYPE i = idx; i < iter_bound; i += LDIM_0) {
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
bool in_range = (i < size);
is_topk = (x == known_bits) && in_range;
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
bool has_topk = in_range && (v == topkth_value);
is_topk &= (out_idx+k2) <= k;
if (is_topk) {
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1];
if (has_topk && index < topk_remaining) {
COUNT_TYPE write_idx = write_base + index;
#if WRITE_VALUE == 1
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+k2-1) * dstv_strides_0) = xval
;
ptr_at(dstv, write_idx * dstv_strides_0) = v
;
#endif
#endif
#if WRITE_INDEX == 1
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+k2-1) * dsti_strides_0) = (INDEX_TYPE)(inp_idx+ i)
;
ptr_at(dsti, write_idx * dsti_strides_0) = (INDEX_TYPE)i
;
#endif
#endif
}
}
local_barrier();
if (idx == blockDim.x - 1)
if (carry >= topk_remaining)
k2 += out_idx;
local_barrier();
if(k2 >= k)
break;
break;
topk_remaining -= carry;
write_base += carry;
}
}
}
}
theano/gpuarray/sort.py
浏览文件 @
7b13e2f4
...
@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def
gpu_kernels
(
self
,
node
,
nodename
):
def
gpu_kernels
(
self
,
node
,
nodename
):
# load kernel source
# load kernel source
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
knames
=
[
'k_topk_dense'
,
'k_topk_dense_large'
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
common_ext
=
{
b
'cuda'
:
'.cuh'
,
b
'opencl'
:
'.h'
}[
device_type
]
common_ext
=
{
b
'cuda'
:
'.cuh'
,
b
'opencl'
:
'.h'
}[
device_type
]
kernel_src
=
{}
for
kname
in
knames
:
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'c_code'
,
kname
+
kernel_ext
),
'r'
)
as
f
:
kernel_src
[
kname
]
=
f
.
read
()
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'c_code'
,
'k_topk_common'
+
common_ext
),
'r'
)
as
f
:
common_src
=
f
.
read
()
# prepare "$" macros
# prepare "$" macros
if
device_type
==
b
'cuda'
:
if
device_type
==
b
'cuda'
:
...
@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
elif
device_type
==
b
'opencl'
:
elif
device_type
==
b
'opencl'
:
raise
NotImplementedError
()
raise
NotImplementedError
()
# compile kernels
# setup parameters
kernels
=
[]
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
for
_
in
range
(
int
(
self
.
return_values
)
+
int
(
self
.
return_indices
)
):
for
_
in
range
(
self
.
return_values
+
self
.
return_indices
):
param_types
.
append
(
ga
.
GpuArray
)
# dst*
param_types
.
append
(
ga
.
GpuArray
)
# dst*
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# dst*_strides
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# dst*_strides
param_types
.
append
(
ga
.
SIZE
)
# k
param_types
.
append
(
ga
.
SIZE
)
# k
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
append
(
ga
.
SIZE
)
# size
param_types
.
append
(
ga
.
SIZE
)
# size
kernels
.
append
(
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense'
])
.
substitute
(
**
subs
),
# load and compile kernels
name
=
'k_topk_dense'
,
with
open
(
os
.
path
.
join
(
params
=
param_types
,
os
.
path
.
dirname
(
__file__
),
'c_code'
,
'k_topk_common'
+
common_ext
flags
=
flags
,
))
as
f
:
objvar
=
'k_topk_dense_'
+
nodename
common_src
=
f
.
read
()
))
param_types
.
append
(
np
.
uint16
)
# inp_per_thread
kernels
=
[]
kernels
.
append
(
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense_large'
])
.
substitute
(
**
subs
),
def
build_kernel
(
fname
,
kname
,
subs
):
name
=
'k_topk_dense_large'
,
with
open
(
os
.
path
.
join
(
params
=
param_types
,
os
.
path
.
dirname
(
__file__
),
'c_code'
,
fname
))
as
f
:
flags
=
flags
,
kernel_src
=
f
.
read
()
objvar
=
'k_topk_dense_large_'
+
nodename
ker
=
Kernel
(
))
code
=
Template
(
common_src
+
kernel_src
)
.
substitute
(
**
subs
),
name
=
kname
,
params
=
param_types
,
flags
=
flags
,
objvar
=
kname
+
nodename
)
return
ker
subs
[
'count_t'
]
=
'int'
kernels
.
append
(
build_kernel
(
'k_topk_dense'
+
kernel_ext
,
'k_topk_dense'
,
subs
))
subs
[
'kname'
]
=
'k_topk_dense_large'
kernels
.
append
(
build_kernel
(
'k_topk_dense_large'
+
kernel_ext
,
'k_topk_dense_large'
,
subs
))
subs
[
'count_t'
]
=
'long long'
subs
[
'kname'
]
=
'k_topk_dense_xlarge'
kernels
.
append
(
build_kernel
(
'k_topk_dense_large'
+
kernel_ext
,
'k_topk_dense_xlarge'
,
subs
))
return
kernels
return
kernels
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
...
@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError,
PyExc_ValueError,
"topk: kth must not be zero");
"topk: kth must not be zero");
%(fail)
s;
%(fail)
s;
} else if (dims[
%(axis)
d] < odims[
%(axis)
d]){
} else if (dims[
%(axis)
d] < odims[
%(axis)
d])
{
PyErr_SetString(
PyErr_SetString(
PyExc_ValueError,
PyExc_ValueError,
"topk: kth cannot be larger than the size of specified axis
%(axis)
d");
"topk: kth cannot be larger than the size of specified axis
%(axis)
d");
%(fail)
s;
%(fail)
s;
} else if (dims[
%(axis)
d] >= (1u << 31)) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, array size of specified axis cannot larger or equal than 2^31");
%(fail)
s;
}
}
%(prep_output)
s
%(prep_output)
s
...
@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
for(int i=0; i<
%(ndim)
d; ++i) {
for
(int i=0; i<
%(ndim)
d; ++i) {
if (i!=
%(axis)
d)
if (i!=
%(axis)
d)
grd[0] *= dims[i];
grd[0] *= dims[i];
else
else
...
@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_dvstrides)
s;
%(def_dvstrides)
s;
%(def_distrides)
s;
%(def_distrides)
s;
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
// inputs per thread
unsigned short ipt = (dims[
%(axis)
d] + (
%(MAX_TPB)
d / 2)-1) / (
%(MAX_TPB)
d / 2);
void* args[] = {
void* args[] = {
%(dims)
s
%(dims)
s
%(params_dv)
s
%(params_dv)
s
...
@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
(void*)(
%(x)
s->ga.data),
(void*)(
%(x)
s->ga.data),
%(sstrides)
s,
%(sstrides)
s,
(void*)(dims+
%(axis)
d),
(void*)(dims+
%(axis)
d),
(void*)(&ipt)
};
};
int err;
int err;
if (blk[0] >
%(MAX_TPB)
d) {
if (dims[
%(axis)
d] > PY_SSIZE_T_MAX) {
// LAUNCH_OUT_OF_RESOURCE if a 1024 sized block is used
PyErr_SetString(
blk[0] =
%(MAX_TPB)
d / 2;
PyExc_ValueError,
"topk: array size on specified axis is too large, should be less than PY_SSIZE_T_MAX.");
%(fail)
s;
} else if (dims[
%(axis)
d] > (1u << 31)) {
blk[0] =
%(MAX_TPB)
d;
err = GpuKernel_call(
&k_topk_dense_xlarge
%(nodename)
s, 3,
grd, blk, 0, args);
} else if (blk[0] >
%(MAX_TPB)
d) {
blk[0] =
%(MAX_TPB)
d;
err = GpuKernel_call(
err = GpuKernel_call(
&k_topk_dense_large
_
%(nodename)
s, 3,
&k_topk_dense_large
%(nodename)
s, 3,
grd, blk, 0, args);
grd, blk, 0, args);
} else {
} else {
err = GpuKernel_call(
err = GpuKernel_call(
&k_topk_dense
_
%(nodename)
s, 3,
&k_topk_dense
%(nodename)
s, 3,
grd, blk, 0, args);
grd, blk, 0, args);
}
}
if (err != GA_NO_ERROR) {
if (err != GA_NO_ERROR) {
...
...
theano/tensor/sort.py
浏览文件 @
7b13e2f4
...
@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
...
@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
assert
-
ndim
<=
axis
<
ndim
assert
-
ndim
<=
axis
<
ndim
axis
%=
ndim
axis
%=
ndim
if
k
==
0
:
if
k
==
0
:
raise
ValueError
(
'topk: k cannot be zero'
)
raise
ValueError
(
'topk: kth cannot be zero'
)
elif
k
>
x
.
shape
[
axis
]:
raise
ValueError
(
'topk: kth cannot be larger than the size of specified axis
%
d'
%
axis
)
if
abs
(
k
)
==
1
:
if
abs
(
k
)
==
1
:
# negative k means min instead of max
# negative k means min instead of max
fn_max
=
[
None
,
np
.
max
,
np
.
min
][
k
]
fn_max
=
[
None
,
np
.
max
,
np
.
min
][
k
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论