Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7b13e2f4
提交
7b13e2f4
authored
6月 25, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multiple improvements to gpu topk
- added xlarge kernel to handle array size >= 2^31
- ported original pytorch kernel
- various small fixes
上级
330dd345
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
62 行增加
和
50 行删除
+62
-50
k_topk_common.cuh
theano/gpuarray/c_code/k_topk_common.cuh
+7
-1
k_topk_dense.cu
theano/gpuarray/c_code/k_topk_dense.cu
+2
-3
k_topk_dense_large.cu
theano/gpuarray/c_code/k_topk_dense_large.cu
+0
-0
sort.py
theano/gpuarray/sort.py
+49
-45
sort.py
theano/tensor/sort.py
+4
-1
没有找到文件。
theano/gpuarray/c_code/k_topk_common.cuh
浏览文件 @
7b13e2f4
...
@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
...
@@ -260,6 +260,12 @@ struct RadixConfig<ga_half> {
#error "RADIX_SIZE must be smaller than warp size (32)"
#error "RADIX_SIZE must be smaller than warp size (32)"
#endif
#endif
// Device-side atomicAdd overload for signed 64-bit (long long) operands.
//
// Forwards to an atomicAdd overload that accepts unsigned long long
// arguments (presumably the CUDA built-in) by reinterpreting both the
// destination pointer and the source reference as unsigned long long.
//
// dst: address of the value to be added to.
// src: amount to add; taken by non-const reference, so callers must pass an
//      lvalue. NOTE(review): confirm the reference parameter is intentional.
//
// Returns nothing: whatever the forwarded atomicAdd call returns is discarded.
void __device__ atomicAdd(long long *dst, long long &src) {
atomicAdd(
// Same address, viewed as unsigned so the unsigned overload is selected.
reinterpret_cast<unsigned long long*>(dst),
// Same bits, viewed as unsigned; the object itself is not modified here.
reinterpret_cast<unsigned long long&>(src));
}
template <typename T>
template <typename T>
static inline __device__ T binary_cumsum(
static inline __device__ T binary_cumsum(
int idx, int warp_id, T* smem, bool value) {
int idx, int warp_id, T* smem, bool value) {
...
@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
...
@@ -343,7 +349,7 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
// read array element using raw(byte) offset
// read array element using raw(byte) offset
template <typename T>
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
static __device__ inline T ptr_read
_cached
(T *ptr, ga_ssize offset) {
return __ldg(((T*)((char*)ptr + offset)));
return __ldg(((T*)((char*)ptr + offset)));
}
}
theano/gpuarray/c_code/k_topk_dense.cu
浏览文件 @
7b13e2f4
...
@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
...
@@ -29,9 +29,8 @@ KERNEL void k_topk_dense(
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
ga_size gid = GID_0, gidx;
$set_slice
$set_slice
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
...
@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
...
@@ -76,6 +75,7 @@ KERNEL void k_topk_dense(
}
}
local_barrier();
local_barrier();
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
if (idx == 0) {
ga_int sum = k2;
ga_int sum = k2;
...
@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
...
@@ -130,4 +130,3 @@ KERNEL void k_topk_dense(
#endif
#endif
}
}
}
}
theano/gpuarray/c_code/k_topk_dense_large.cu
浏览文件 @
7b13e2f4
差异被折叠。
点击展开。
theano/gpuarray/sort.py
浏览文件 @
7b13e2f4
...
@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -58,20 +58,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def
gpu_kernels
(
self
,
node
,
nodename
):
def
gpu_kernels
(
self
,
node
,
nodename
):
# load kernel source
# load kernel source
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
knames
=
[
'k_topk_dense'
,
'k_topk_dense_large'
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
common_ext
=
{
b
'cuda'
:
'.cuh'
,
b
'opencl'
:
'.h'
}[
device_type
]
common_ext
=
{
b
'cuda'
:
'.cuh'
,
b
'opencl'
:
'.h'
}[
device_type
]
kernel_src
=
{}
for
kname
in
knames
:
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'c_code'
,
kname
+
kernel_ext
),
'r'
)
as
f
:
kernel_src
[
kname
]
=
f
.
read
()
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'c_code'
,
'k_topk_common'
+
common_ext
),
'r'
)
as
f
:
common_src
=
f
.
read
()
# prepare "$" macros
# prepare "$" macros
if
device_type
==
b
'cuda'
:
if
device_type
==
b
'cuda'
:
...
@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -108,31 +96,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
elif
device_type
==
b
'opencl'
:
elif
device_type
==
b
'opencl'
:
raise
NotImplementedError
()
raise
NotImplementedError
()
# compile kernels
# setup parameters
kernels
=
[]
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
for
_
in
range
(
int
(
self
.
return_values
)
+
int
(
self
.
return_indices
)
):
for
_
in
range
(
self
.
return_values
+
self
.
return_indices
):
param_types
.
append
(
ga
.
GpuArray
)
# dst*
param_types
.
append
(
ga
.
GpuArray
)
# dst*
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# dst*_strides
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# dst*_strides
param_types
.
append
(
ga
.
SIZE
)
# k
param_types
.
append
(
ga
.
SIZE
)
# k
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
append
(
ga
.
SIZE
)
# size
param_types
.
append
(
ga
.
SIZE
)
# size
kernels
.
append
(
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense'
])
.
substitute
(
**
subs
),
# load and compile kernels
name
=
'k_topk_dense'
,
with
open
(
os
.
path
.
join
(
params
=
param_types
,
os
.
path
.
dirname
(
__file__
),
'c_code'
,
'k_topk_common'
+
common_ext
flags
=
flags
,
))
as
f
:
objvar
=
'k_topk_dense_'
+
nodename
common_src
=
f
.
read
()
))
param_types
.
append
(
np
.
uint16
)
# inp_per_thread
kernels
=
[]
kernels
.
append
(
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense_large'
])
.
substitute
(
**
subs
),
def
build_kernel
(
fname
,
kname
,
subs
):
name
=
'k_topk_dense_large'
,
with
open
(
os
.
path
.
join
(
params
=
param_types
,
os
.
path
.
dirname
(
__file__
),
'c_code'
,
fname
))
as
f
:
flags
=
flags
,
kernel_src
=
f
.
read
()
objvar
=
'k_topk_dense_large_'
+
nodename
ker
=
Kernel
(
))
code
=
Template
(
common_src
+
kernel_src
)
.
substitute
(
**
subs
),
name
=
kname
,
params
=
param_types
,
flags
=
flags
,
objvar
=
kname
+
nodename
)
return
ker
subs
[
'count_t'
]
=
'int'
kernels
.
append
(
build_kernel
(
'k_topk_dense'
+
kernel_ext
,
'k_topk_dense'
,
subs
))
subs
[
'kname'
]
=
'k_topk_dense_large'
kernels
.
append
(
build_kernel
(
'k_topk_dense_large'
+
kernel_ext
,
'k_topk_dense_large'
,
subs
))
subs
[
'count_t'
]
=
'long long'
subs
[
'kname'
]
=
'k_topk_dense_xlarge'
kernels
.
append
(
build_kernel
(
'k_topk_dense_large'
+
kernel_ext
,
'k_topk_dense_xlarge'
,
subs
))
return
kernels
return
kernels
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
...
@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -204,16 +207,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
PyExc_ValueError,
PyExc_ValueError,
"topk: kth must not be zero");
"topk: kth must not be zero");
%(fail)
s;
%(fail)
s;
} else if (dims[
%(axis)
d] < odims[
%(axis)
d]){
} else if (dims[
%(axis)
d] < odims[
%(axis)
d])
{
PyErr_SetString(
PyErr_SetString(
PyExc_ValueError,
PyExc_ValueError,
"topk: kth cannot be larger than the size of specified axis
%(axis)
d");
"topk: kth cannot be larger than the size of specified axis
%(axis)
d");
%(fail)
s;
%(fail)
s;
} else if (dims[
%(axis)
d] >= (1u << 31)) {
PyErr_SetString(
PyExc_ValueError,
"topk: on GPU, array size of specified axis cannot larger or equal than 2^31");
%(fail)
s;
}
}
%(prep_output)
s
%(prep_output)
s
...
@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -221,7 +219,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
for(int i=0; i<
%(ndim)
d; ++i) {
for
(int i=0; i<
%(ndim)
d; ++i) {
if (i!=
%(axis)
d)
if (i!=
%(axis)
d)
grd[0] *= dims[i];
grd[0] *= dims[i];
else
else
...
@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -233,8 +231,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_dvstrides)
s;
%(def_dvstrides)
s;
%(def_distrides)
s;
%(def_distrides)
s;
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
// inputs per thread
unsigned short ipt = (dims[
%(axis)
d] + (
%(MAX_TPB)
d / 2)-1) / (
%(MAX_TPB)
d / 2);
void* args[] = {
void* args[] = {
%(dims)
s
%(dims)
s
%(params_dv)
s
%(params_dv)
s
...
@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -243,19 +239,27 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
(void*)(
%(x)
s->ga.data),
(void*)(
%(x)
s->ga.data),
%(sstrides)
s,
%(sstrides)
s,
(void*)(dims+
%(axis)
d),
(void*)(dims+
%(axis)
d),
(void*)(&ipt)
};
};
int err;
int err;
if (blk[0] >
%(MAX_TPB)
d) {
if (dims[
%(axis)
d] > PY_SSIZE_T_MAX) {
// LAUNCH_OUT_OF_RESOURCE if a 1024 sized block is used
PyErr_SetString(
blk[0] =
%(MAX_TPB)
d / 2;
PyExc_ValueError,
"topk: array size on specified axis is too large, should be less than PY_SSIZE_T_MAX.");
%(fail)
s;
} else if (dims[
%(axis)
d] > (1u << 31)) {
blk[0] =
%(MAX_TPB)
d;
err = GpuKernel_call(
&k_topk_dense_xlarge
%(nodename)
s, 3,
grd, blk, 0, args);
} else if (blk[0] >
%(MAX_TPB)
d) {
blk[0] =
%(MAX_TPB)
d;
err = GpuKernel_call(
err = GpuKernel_call(
&k_topk_dense_large
_
%(nodename)
s, 3,
&k_topk_dense_large
%(nodename)
s, 3,
grd, blk, 0, args);
grd, blk, 0, args);
} else {
} else {
err = GpuKernel_call(
err = GpuKernel_call(
&k_topk_dense
_
%(nodename)
s, 3,
&k_topk_dense
%(nodename)
s, 3,
grd, blk, 0, args);
grd, blk, 0, args);
}
}
if (err != GA_NO_ERROR) {
if (err != GA_NO_ERROR) {
...
...
theano/tensor/sort.py
浏览文件 @
7b13e2f4
...
@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
...
@@ -227,7 +227,10 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
assert
-
ndim
<=
axis
<
ndim
assert
-
ndim
<=
axis
<
ndim
axis
%=
ndim
axis
%=
ndim
if
k
==
0
:
if
k
==
0
:
raise
ValueError
(
'topk: k cannot be zero'
)
raise
ValueError
(
'topk: kth cannot be zero'
)
elif
k
>
x
.
shape
[
axis
]:
raise
ValueError
(
'topk: kth cannot be larger than the size of specified axis
%
d'
%
axis
)
if
abs
(
k
)
==
1
:
if
abs
(
k
)
==
1
:
# negative k means min instead of max
# negative k means min instead of max
fn_max
=
[
None
,
np
.
max
,
np
.
min
][
k
]
fn_max
=
[
None
,
np
.
max
,
np
.
min
][
k
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论