Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1728cd8c
提交
1728cd8c
authored
6月 13, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
GPU speedup
上级
aa14269f
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
191 行增加
和
125 行删除
+191
-125
k_topk_common.cuh
theano/gpuarray/k_topk_common.cuh
+88
-22
k_topk_dense.cu
theano/gpuarray/k_topk_dense.cu
+43
-47
k_topk_dense_large.cu
theano/gpuarray/k_topk_dense_large.cu
+51
-53
sort.py
theano/gpuarray/sort.py
+9
-3
没有找到文件。
theano/gpuarray/k_topk_common.cuh
浏览文件 @
1728cd8c
...
@@ -48,6 +48,76 @@ POSSIBILITY OF SUCH DAMAGE.
...
@@ -48,6 +48,76 @@ POSSIBILITY OF SUCH DAMAGE.
#endif
#endif
__device__ __forceinline__ int lane_id() {
int id;
asm("mov.s32 %0, %laneid;" : "=r"(id) );
return id;
}
__device__ __forceinline__ unsigned lane_mask_lt() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
return mask;
}
__device__ __forceinline__ unsigned lane_mask_le() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
return mask;
}
__device__ __forceinline__ unsigned lane_mask_gt() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
return mask;
}
__device__ __forceinline__ unsigned lane_mask_ge() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
return mask;
}
template <typename T>
struct Bitfield {};
template <>
struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int get(unsigned int val, int pos, int len) {
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
}
static __device__ __forceinline__
unsigned int set(unsigned int val, unsigned int toInsert, int pos, int len) {
unsigned int ret;
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
}
};
template <>
struct Bitfield<unsigned long long int> {
static __device__ __forceinline__
unsigned long long int get(unsigned long long int val, int pos, int len) {
unsigned long long int ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
}
static __device__ __forceinline__
unsigned long long int set(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
unsigned long long int ret;
asm("bfi.b64 %0, %1, %2, %3, %4;" :
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
return ret;
}
};
template <typename T>
template <typename T>
struct RadixConfig {
struct RadixConfig {
// Converts a type (maybe float) to an integer representation with the same
// Converts a type (maybe float) to an integer representation with the same
...
@@ -182,10 +252,6 @@ struct RadixConfig<ga_half> {
...
@@ -182,10 +252,6 @@ struct RadixConfig<ga_half> {
#define INPUT_TYPE $inp_t
#define INPUT_TYPE $inp_t
#define INDEX_TYPE $out_t
#define INDEX_TYPE $out_t
#define bitsof(T) (sizeof(T)*8)
#define bitsof(T) (sizeof(T)*8)
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
#define radix_t RadixConfig<INPUT_TYPE>::RadixType
#define radix_t RadixConfig<INPUT_TYPE>::RadixType
#define WRITE_VALUE $write_value
#define WRITE_VALUE $write_value
#define WRITE_INDEX $write_index
#define WRITE_INDEX $write_index
...
@@ -194,28 +260,28 @@ struct RadixConfig<ga_half> {
...
@@ -194,28 +260,28 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)"
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
#endif
static inline __device__ ga_size binary_cumsum(
template <typename T>
int idx, int warp_id, int lane_id, ga_size* smem, bool value) {
static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *no greater than* the current thread
// whose id is *no greater than* the current thread
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// cumsum within warp
// cumsum within warp
ga_uint warp_bits = __ballot(value);
ga_uint warp_bits = __ballot(value);
ga_size warp_sum = __popc(((2<<lane_id)-1
) & warp_bits);
T warp_sum = __popc(lane_mask_le(
) & warp_bits);
if (lane_id == 0)
if (lane_id
()
== 0)
smem[warp_id] = __popc(warp_bits);
smem[warp_id] = __popc(warp_bits);
local_barrier();
local_barrier();
// cumsum across warps in one thread
// cumsum across warps in one thread
if (idx == 0) {
if (idx == 0) {
int current = 0;
T sum = smem[0];
for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) {
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
ga_size v = smem[i];
sum += smem[i];
smem[i] = smem[i]+current;
smem[i] = sum;
current = current+v;
}
}
}
}
...
@@ -229,28 +295,28 @@ static inline __device__ ga_size binary_cumsum(
...
@@ -229,28 +295,28 @@ static inline __device__ ga_size binary_cumsum(
return warp_sum;
return warp_sum;
}
}
static inline __device__ ga_size binary_cumsum_exclusive(
template <typename T>
int idx, int warp_id, int lane_id, ga_size* smem, bool value) {
static inline __device__ T binary_cumsum_exclusive(
int idx, int warp_id, T* smem, bool value) {
// cumsum within 1D thread block, which adds up `value` of all threads
// cumsum within 1D thread block, which adds up `value` of all threads
// whose id is *less than* the current thread
// whose id is *less than* the current thread
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// cumsum within warp
// cumsum within warp
ga_uint warp_bits = __ballot(value);
ga_uint warp_bits = __ballot(value);
ga_size warp_sum = __popc(((1<<lane_id)-1
) & warp_bits);
T warp_sum = __popc(lane_mask_lt(
) & warp_bits);
if (lane_id == 0)
if (lane_id
()
== 0)
smem[warp_id] = __popc(warp_bits);
smem[warp_id] = __popc(warp_bits);
local_barrier();
local_barrier();
// cumsum across warps in one thread
// cumsum across warps in one thread
if (idx == 0) {
if (idx == 0) {
int current = 0;
T sum = smem[0];
for (int i = 0; i < LDIM_0 / GA_WARP_SIZE; ++i) {
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) {
ga_size v = smem[i];
sum += smem[i];
smem[i] = smem[i]+current;
smem[i] = sum;
current = current+v;
}
}
}
}
...
...
theano/gpuarray/k_topk_dense.cu
浏览文件 @
1728cd8c
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024)
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
KERNEL void k_topk_dense(
$dims
$dims
...
@@ -15,17 +20,14 @@ KERNEL void k_topk_dense(
...
@@ -15,17 +20,14 @@ KERNEL void k_topk_dense(
$src_strides
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
ga_size size) {
LOCAL_MEM ga_size smem[32 * RADIX_SIZE];
LOCAL_MEM ga_int smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
LOCAL_MEM ga_int k2;
bool is_topk=true, is_topkth=true;
const ga_uint idx = LID_0;
bool is_topk= (idx < size);
bool is_topkth = is_topk;
ga_size out_idx;
ga_size out_idx;
const ga_ushort idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// 0. get the slice for thread block to work on
...
@@ -41,90 +43,84 @@ KERNEL void k_topk_dense(
...
@@ -41,90 +43,84 @@ KERNEL void k_topk_dense(
//}
//}
// get input and its radix friendly form
// get input and its radix friendly form
const INPUT_TYPE xval = i
n_range
? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
const INPUT_TYPE xval = i
s_topk
? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x =
in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0
;
radix_t x =
RadixConfig<INPUT_TYPE>::convert(xval)
;
// resolve negative k
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (k<0) { x = ~x; k = -k; }
if (idx==0)
{
if (idx==0)
k2 = k;
k2 = k;
bins[RADIX_SIZE] = 1;
}
// 1. filter is_topk and is_topkth using radix select
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
const ga_int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*ga_int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp
// count within warp
#pragma unroll
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool
incr_bin = (bin == digit) && is_topkth && in_range
;
bool
vote = (bin == digit) && is_topkth
;
ga_uint
incr_bin_warp = __ballot(incr_bin
);
ga_uint
votes = __ballot(vote
);
if (lane_id==0)
if (lane_id
()
==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(
incr_bin_warp
);
smem[bin + RADIX_SIZE*warp_id] = __popc(
votes
);
}
}
local_barrier();
local_barrier();
// sum counts across all warps
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
if (idx < RADIX_SIZE) {
ga_int sum = smem[idx];
#pragma unroll
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
sum += smem[idx + w];
smem[idx] = sum;
}
}
local_barrier();
local_barrier();
//
bins = k - cumsum(smem[:RADIX_SIZE
])
//
smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1
])
if (idx == 0) {
if (idx == 0) {
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
ga_int sum = k2;
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
#pragma unroll
#pragma unroll
for (int bin=RADIX_SIZE-1; bin; --bin) {
for (int bin=RADIX_SIZE-1; bin
>=0
; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1
];
sum -= smem[bin
];
if (bins[bin-1] > 0)
smem[bin] = sum;
k2 = bins[bin-1]
;
k2 = (sum > 0) ? sum : k2
;
}
}
smem[RADIX_SIZE] = 1;
}
}
local_barrier();
local_barrier();
if (is_topkth) {
// smem -> count
is_topk &= (smem[digit+1] > 0);
// bins -> k2 - cumsum(count)
is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
local_barrier();
}
}
// set k2 as number of exceeding values
if (idx==0) {
if (idx==0) {
#pragma unroll
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (
bins
[bin] <= 0) {
if (
smem
[bin] <= 0) {
exceed = -bins
[bin];
k2 = -smem
[bin];
break;
break;
}
}
}
}
}
}
local_barrier();
local_barrier();
// 2. find the index of output array, if exists
// 2. find the index of output array, if exists
if (
exceed
!= 0) {
if (
k2
!= 0) {
// top_kth value may not be unique, so we need to
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive(idx, warp_id, lane_id, smem, is_topkth);
out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
is_topk &= ((!is_topkth) || out_idx>=exceed);
if ((out_idx >= k2) && is_topkth)
is_topk = false;
local_barrier();
}
}
// perform binary cumsum on is_topk to determine the indices to put result
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive(idx, warp_id, lane_id, smem, is_topk);
out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);
local_barrier();
if (is_topk) {
if (is_topk) {
#if WRITE_VALUE == 1
#if WRITE_VALUE == 1
...
...
theano/gpuarray/k_topk_dense_large.cu
浏览文件 @
1728cd8c
// works when length on axis is larger than max allowed threads in block (1024)
#define RADIX_BITS 2
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is in [1025, 2^31-1]
KERNEL void k_topk_dense_large(
KERNEL void k_topk_dense_large(
$dims
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
...
@@ -15,23 +20,22 @@ KERNEL void k_topk_dense_large(
...
@@ -15,23 +20,22 @@ KERNEL void k_topk_dense_large(
$src_strides
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) {
ga_size size, ga_ushort inp_per_thread) {
LOCAL_MEM ga_size smem[32 * RADIX_SIZE];
LOCAL_MEM ga_int smem[32];
LOCAL_MEM radix_t known_bits, known_bits_mask;
LOCAL_MEM radix_t known_bits;
ga_size out_idx;
LOCAL_MEM ga_uint k2;
ga_size LOCAL_MEM write_base;
int counts[RADIX_SIZE];
unsigned out_idx;
INPUT_TYPE xval;
INPUT_TYPE xval;
radix_t x;
radix_t x;
ga_int i;
bool in_range, is_topk;
bool in_range, is_topk;
const ga_size idx = LID_0;
const ga_uint idx = LID_0;
ga_size LOCAL_MEM k2;
const ga_uint inp_idx = idx * inp_per_thread;
const ga_ushort warp_id = idx / GA_WARP_SIZE;
const ga_int warp_id = idx / GA_WARP_SIZE;
const ga_ushort lane_id = idx % GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_
size
gid = GID_0, gidx;
ga_
uint
gid = GID_0, gidx;
$set_slice
$set_slice
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gidx = gid % dims_$${i};
...
@@ -42,55 +46,45 @@ KERNEL void k_topk_dense_large(
...
@@ -42,55 +46,45 @@ KERNEL void k_topk_dense_large(
//}
//}
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
LOCAL_MEM radix_t inv_bits;
if (idx==0) {
if (idx==0) {
known_bits = known_bits_mask = 0;
known_bits = 0;
k2 = abs(k);
k2 = (k>=0) ? k : -k;
inv_bits = (k>=0) ? 0 : (~0);
write_base = 0;
}
}
const radix_t inv_bits = (k>=0) ? 0 : ~0;
if (k<0) { k = -k; }
if (k<0) { k = -k; }
local_barrier();
local_barrier();
// 1. find bits of top-k-th value using radix select
// 1. find bits of top-k-th value using radix select
#pragma unroll
#pragma unroll
for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
/*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
#pragma unroll
if (lane_id == 0) {
for (int j=0; j<RADIX_SIZE; ++j)
#pragma unroll
counts[j] = 0;
for (int bin=0; bin<RADIX_SIZE; ++bin) {
if (warp_id == 0)
smem[bin + warp_id*RADIX_SIZE] = 0;
smem[idx] = 0;
}
}
local_barrier();
local_barrier();
// count within warp
for (int j=0; j<inp_per_thread; ++j) {
for (int j=0; j<inp_per_thread; ++j) {
in_range = (i
dx*inp_per_thread
+j) < size;
in_range = (i
np_idx
+j) < size;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
// count within warp
#pragma unroll
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (
bool incr_bin = (
(bin == digit) &&
(bin == digit) &&
((x&known_bits_mask) == known_bits) &&
((x >> (i+RADIX_BITS)) == known_bits) && in_range);
in_range);
counts[bin] += __popc(__ballot(incr_bin));
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
}
}
}
}
local_barrier();
local_barrier();
// sum counts across all warps
// sum counts across all warps
// TODO: test in-block parallel sum?
if (lane_id() < RADIX_SIZE) {
if (idx < RADIX_SIZE) {
atomicAdd(&smem[lane_id()], counts[lane_id()]);
for(int w=RADIX_SIZE;
w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
}
local_barrier();
local_barrier();
...
@@ -99,8 +93,7 @@ KERNEL void k_topk_dense_large(
...
@@ -99,8 +93,7 @@ KERNEL void k_topk_dense_large(
#pragma unroll
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] >= k2) {
if (smem[bin] >= k2) {
known_bits |= (((radix_t)bin) << i);
known_bits = (known_bits << RADIX_BITS) | bin;
known_bits_mask |= (((radix_t)(RADIX_SIZE-1)) << i);
break;
break;
} else
} else
k2 -= smem[bin];
k2 -= smem[bin];
...
@@ -109,50 +102,55 @@ KERNEL void k_topk_dense_large(
...
@@ -109,50 +102,55 @@ KERNEL void k_topk_dense_large(
local_barrier();
local_barrier();
}
}
// now we use k2 for base index to write output
if (idx == 0)
k2 = 0;
local_barrier();
// 2. write values smaller than top-kth
// 2. write values smaller than top-kth
for (i=0; i<inp_per_thread; ++i) {
for (i
nt i
=0; i<inp_per_thread; ++i) {
in_range = (i
dx*inp_per_thread
+i) < size;
in_range = (i
np_idx
+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range;
is_topk = (x > known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id,
lane_id,
smem, is_topk);
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
if (is_topk) {
if (is_topk) {
#if WRITE_VALUE == 1
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+
write_base
-1) * dstv_strides_0) = xval;
ptr_at(dstv, (out_idx+
k2
-1) * dstv_strides_0) = xval;
#endif
#endif
#if WRITE_INDEX == 1
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+
write_base
-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
ptr_at(dsti, (out_idx+
k2
-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
#endif
}
}
local_barrier();
local_barrier();
if (idx == blockDim.x - 1)
if (idx == blockDim.x - 1)
write_base
+= out_idx;
k2
+= out_idx;
local_barrier();
local_barrier();
}
}
// 3. write values equal to top-kth
// 3. write values equal to top-kth
for (i=0; i<inp_per_thread; ++i) {
for (i
nt i
=0; i<inp_per_thread; ++i) {
in_range = (i
dx*inp_per_thread
+i) < size;
in_range = (i
np_idx
+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x == known_bits) && in_range;
is_topk = (x == known_bits) && in_range;
out_idx = binary_cumsum(idx, warp_id,
lane_id,
smem, is_topk);
out_idx = binary_cumsum(idx, warp_id, smem, is_topk);
is_topk
= ((out_idx+write_base) <= abs(k)) && is_top
k;
is_topk
&= (out_idx+k2) <=
k;
if (is_topk) {
if (is_topk) {
#if WRITE_VALUE == 1
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+
write_base
-1) * dstv_strides_0) = xval;
ptr_at(dstv, (out_idx+
k2
-1) * dstv_strides_0) = xval;
#endif
#endif
#if WRITE_INDEX == 1
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+
write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread
+ i);
ptr_at(dsti, (out_idx+
k2-1) * dsti_strides_0) = (INDEX_TYPE)(inp_idx
+ i);
#endif
#endif
}
}
local_barrier();
local_barrier();
if (idx == blockDim.x - 1)
if (idx == blockDim.x - 1)
write_base
+= out_idx;
k2
+= out_idx;
local_barrier();
local_barrier();
if(
write_base >= abs(k)
)
if(
k2 >= k
)
break;
break;
}
}
}
}
...
...
theano/gpuarray/sort.py
浏览文件 @
1728cd8c
...
@@ -20,6 +20,7 @@ except ImportError as e:
...
@@ -20,6 +20,7 @@ except ImportError as e:
pass
pass
# TODO GPU sort / argsort
# TODO GPU sort / argsort
# TODO support when k >= 2^31
class
GpuTopKOp
(
GpuKernelBase
,
TopKOp
):
class
GpuTopKOp
(
GpuKernelBase
,
TopKOp
):
...
@@ -34,7 +35,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -34,7 +35,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
self
,
axis
=-
1
,
self
,
axis
=-
1
,
idx_dtype
=
'int64'
,
idx_dtype
=
'int64'
,
return_values
=
True
,
return_values
=
True
,
return_indices
=
True
):
return_indices
=
True
):
GpuKernelBase
.
__init__
(
self
)
GpuKernelBase
.
__init__
(
self
)
TopKOp
.
__init__
(
TopKOp
.
__init__
(
self
,
axis
=
axis
,
self
,
axis
=
axis
,
...
@@ -206,10 +208,14 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -206,10 +208,14 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError,
PyExc_ValueError,
"topk: k cannot larger than size on specified axis
%(axis)
d");
"topk: k cannot larger than size on specified axis
%(axis)
d");
%(fail)
s;
%(fail)
s;
} else if (odims[
%(axis)
d] > 0x7fffffffu) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, k cannot larger or equal than 2^31");
%(fail)
s;
}
}
%(prep_output)
s
%(prep_output)
s
// TODO better scheduling?
size_t blk[6];
size_t blk[6];
size_t *grd = blk+3;
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
blk[0] = blk[1] = blk[2] = 1;
...
@@ -227,7 +233,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -227,7 +233,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_distrides)
s;
%(def_distrides)
s;
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
// inputs per thread
// inputs per thread
unsigned short ipt = (dims[
%(axis)
d] + (
%(MAX_TPB)
d
/2)-1) / (
%(MAX_TPB)
d/
2);
unsigned short ipt = (dims[
%(axis)
d] + (
%(MAX_TPB)
d
/ 2)-1) / (
%(MAX_TPB)
d /
2);
void* args[] = {
void* args[] = {
%(dims)
s
%(dims)
s
%(params_dv)
s
%(params_dv)
s
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论