Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5cc9c326
提交
5cc9c326
authored
9月 21, 2017
作者:
abergeron
提交者:
GitHub
9月 21, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5959 from khaotik/topk
TopKOp implementation
上级
0d47e204
61de1573
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
183 行增加
和
11 行删除
+183
-11
opt.py
theano/gof/opt.py
+14
-8
__init__.py
theano/gpuarray/__init__.py
+1
-1
topk_common.cuh
theano/gpuarray/c_code/topk_common.cuh
+0
-0
topk_dense.cu
theano/gpuarray/c_code/topk_dense.cu
+133
-0
topk_dense_large.cu
theano/gpuarray/c_code/topk_dense_large.cu
+0
-0
sort.py
theano/gpuarray/sort.py
+0
-0
__init__.py
theano/tensor/__init__.py
+1
-1
opt.py
theano/tensor/opt.py
+33
-0
sort.py
theano/tensor/sort.py
+0
-0
test_sort.py
theano/tensor/tests/test_sort.py
+0
-0
unittest_tools.py
theano/tests/unittest_tools.py
+1
-1
没有找到文件。
theano/gof/opt.py
浏览文件 @
5cc9c326
...
...
@@ -1367,21 +1367,27 @@ class LocalOptGroup(LocalOptimizer):
self
.
process_count
[
opt
]
+=
1
if
not
new_repl
:
continue
else
:
if
self
.
profile
:
self
.
node_created
[
opt
]
+=
len
(
graph
.
ops
(
fgraph
.
variables
,
new_repl
))
self
.
applied_true
[
opt
]
+=
1
break
# break from the for loop over optimization.
if
isinstance
(
new_repl
,
(
tuple
,
list
)):
new_vars
=
new_repl
else
:
# It must be a dict
new_vars
=
list
(
new_repl
.
values
())
if
self
.
profile
:
self
.
node_created
[
opt
]
+=
len
(
graph
.
ops
(
fgraph
.
variables
,
new_vars
))
self
.
applied_true
[
opt
]
+=
1
break
# break from the for loop over optimization.
if
not
new_repl
:
# No optimization applied in the last iteration
return
repl
# only 1 iteration or we are at the start of the graph.
if
not
self
.
apply_all_opts
or
not
new_repl
[
0
]
.
owner
:
# only 1 iteration
if
not
self
.
apply_all_opts
:
return
new_repl
if
not
new_vars
[
0
]
.
owner
:
# We are at the start of the graph.
return
new_repl
if
len
(
new_repl
)
>
1
:
s
=
set
([
v
.
owner
for
v
in
new_repl
])
assert
len
(
s
)
==
1
repl
=
new_repl
node
=
repl
[
0
]
.
owner
node
=
new_vars
[
0
]
.
owner
@staticmethod
def
print_profile
(
stream
,
prof
,
level
=
0
):
...
...
theano/gpuarray/__init__.py
浏览文件 @
5cc9c326
...
...
@@ -28,7 +28,7 @@ from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
GpuArraySharedVariable
,
gpuarray_shared_constructor
,
reg_context
,
get_context
,
ContextNotDefined
)
from
.basic_ops
import
as_gpuarray_variable
from
.
import
fft
,
dnn
,
opt
,
extra_ops
,
multinomial
,
reduction
,
rng_mrg
,
ctc
from
.
import
fft
,
dnn
,
opt
,
extra_ops
,
multinomial
,
reduction
,
sort
,
rng_mrg
,
ctc
def
transfer
(
x
,
target
):
...
...
theano/gpuarray/c_code/topk_common.cuh
0 → 100644
浏览文件 @
5cc9c326
差异被折叠。
点击展开。
theano/gpuarray/c_code/topk_dense.cu
0 → 100644
浏览文件 @
5cc9c326
// Single-block radix-select top-k kernel (dense variant).
//
// One thread block handles one 1-D slice of the input; thread `idx` owns
// element `idx` of the slice, so this kernel only works while the slice
// length fits in one block (see comment below: max 1024 threads).
// Tokens beginning with `$` ($dims, $dstv, $set_slice, ...) are placeholders
// substituted by the Python host code before compilation; the inline
// comments next to each one show what they expand to.
// Helpers such as RadixConfig, Bitfield, lane_id(), ptr_at()/ptr_add() and
// binary_cumsum_exclusive are presumably defined in topk_common.cuh
// (collapsed in this view) — TODO confirm.
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
#define RADIX_MASK(n) ((RADIX_SIZE-1) << (n*RADIX_BITS))
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
$dims
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ssize_t k,
INPUT_TYPE* src,
$src_strides
// ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
size_t size) {
// per-warp histogram: RADIX_SIZE bins for each of up to 32 warps
__shared__ int smem[32 * RADIX_SIZE];
// k2 = number of top-k slots still to be filled (shared across the block)
__shared__ int k2;
const unsigned int idx = threadIdx.x;
// is_topk: this thread's element is still a top-k candidate
bool is_topk= (idx < size);
// is_topkth: this thread's element may still be the k-th (boundary) value
bool is_topkth = is_topk;
size_t out_idx;
const unsigned char warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
size_t gid = blockIdx.x, gidx;
$set_slice
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
// (RadixConfig maps the value to an unsigned key whose bit order matches
// the value order, so radix digits can be compared directly)
const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);
// resolve negative k
// (negative k requests the k smallest; flipping all key bits reverses the order)
if (k<0) { x = ~x; k = -k; }
if (idx==0)
k2 = k;
// 1. filter is_topk and is_topkth using radix select
// Scan digits from most to least significant; each pass narrows the set of
// candidates for the k-th value by one radix digit.
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp
// NOTE(review): plain __ballot is pre-CUDA-9 API (deprecated in favor of
// __ballot_sync) — fine for the toolkits this targeted, verify on upgrade.
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool vote = (bin == digit) && is_topkth;
unsigned int votes = __ballot(vote);
if (lane_id()==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
}
local_barrier();
// sum counts across all warps
if (idx < RADIX_SIZE) {
int sum = smem[idx];
#pragma unroll
for(int w=RADIX_SIZE; w<blockDim.x*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
sum += smem[idx + w];
smem[idx] = sum;
}
local_barrier();
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
// After this, smem[bin] <= 0 exactly for bins at or above the one holding
// the k-th value; k2 becomes the count still needed inside that bin.
if (idx == 0) {
int sum = k2;
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
sum -= smem[bin];
smem[bin] = sum;
k2 = (sum > 0) ? sum : k2;
}
// sentinel so smem[digit+1] below is valid when digit == RADIX_SIZE-1
smem[RADIX_SIZE] = 1;
}
local_barrier();
if (is_topkth) {
is_topk &= (smem[digit+1] > 0);
is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
}
local_barrier();
}
// set k2 as number of exceeding values
// (ties on the k-th value can make more than k elements pass the filter)
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] <= 0)
break;
k2 = smem[bin];
}
}
local_barrier();
// 2. find the index of output array, if exists
if (k2 != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
if ((out_idx >= k2) && is_topkth)
is_topk = false;
local_barrier();
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
theano/gpuarray/c_code/topk_dense_large.cu
0 → 100644
浏览文件 @
5cc9c326
差异被折叠。
点击展开。
theano/gpuarray/sort.py
0 → 100644
浏览文件 @
5cc9c326
差异被折叠。
点击展开。
theano/tensor/__init__.py
浏览文件 @
5cc9c326
...
...
@@ -40,7 +40,7 @@ from theano.tensor import nnet # used for softmax, sigmoid, etc.
from
theano.gradient
import
Rop
,
Lop
,
grad
,
numeric_grad
,
verify_grad
,
\
jacobian
,
hessian
,
consider_constant
from
theano.tensor.sort
import
sort
,
argsort
from
theano.tensor.sort
import
sort
,
argsort
,
topk
,
argtopk
,
topk_and_argtopk
from
theano.tensor.extra_ops
import
(
DiffOp
,
bincount
,
squeeze
,
repeat
,
bartlett
,
fill_diagonal
,
fill_diagonal_offset
,
cumsum
,
cumprod
,
unravel_index
,
ravel_multi_index
)
...
...
theano/tensor/opt.py
浏览文件 @
5cc9c326
...
...
@@ -35,6 +35,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
advanced_subtensor
,
advanced_subtensor1
,
advanced_inc_subtensor1
)
from
theano.tensor.sort
import
TopKOp
from
theano
import
scalar
from
theano.scalar
import
basic
from
theano.tensor
import
basic
as
T
...
...
@@ -7548,3 +7549,35 @@ def local_merge_alloc(node):
dim_outer
,
T
.
eq
(
dim_outer
,
dim_inner
))
i
+=
1
return
[
T
.
alloc
(
inputs_inner
[
0
],
*
dims_outer
)]
@register_useless('fast_compile')
@gof.local_optimizer([TopKOp])
def local_useless_topk(node):
    """Prune the unused output of a two-output ``TopKOp``.

    ``TopKOp`` produces both values and indices by default.  When exactly
    one of the two outputs has clients, rebuild the op so it computes only
    the output that is actually used, and map the old output to the new
    one.  Returns ``False`` (no change) when both or neither output is
    used.
    """
    op = node.op
    if not isinstance(op, TopKOp):
        return
    if not (op.return_values and op.return_indices):
        # Already a single-output TopKOp: nothing to strip.
        return False

    x, k = node.inputs
    ret_val = bool(node.outputs[0].clients)
    ret_idx = bool(node.outputs[1].clients)
    if not (ret_val ^ ret_idx):
        # both true -> nothing to remove
        # both false -> let pruner handle
        return False

    # ret_idx is True exactly when the indices output (index 1) is the
    # live one, so it doubles as the index of the output to replace.
    old_output = node.outputs[ret_idx]
    new_output = TopKOp(
        axis=op.axis,
        idx_dtype=op.idx_dtype,
        return_values=ret_val,
        return_indices=ret_idx)(x, k)
    return {old_output: new_output}
theano/tensor/sort.py
浏览文件 @
5cc9c326
差异被折叠。
点击展开。
theano/tensor/tests/test_sort.py
浏览文件 @
5cc9c326
差异被折叠。
点击展开。
theano/tests/unittest_tools.py
浏览文件 @
5cc9c326
...
...
@@ -83,7 +83,7 @@ def seed_rng(pseed=None):
def
verify_grad
(
op
,
pt
,
n_tests
=
2
,
rng
=
None
,
*
args
,
**
kwargs
):
"""
Wrapper for
tensor/basic
.py:verify_grad
Wrapper for
gradient
.py:verify_grad
Takes care of seeding the random number generator if None is given
"""
if
rng
is
None
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论