Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
933cb859
提交
933cb859
authored
5月 18, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement new kernel to handle arbitrary shape
上级
bf83d342
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
361 行增加
和
173 行删除
+361
-173
k_topk_common.cuh
theano/gpuarray/k_topk_common.cuh
+47
-162
k_topk_dense.cu
theano/gpuarray/k_topk_dense.cu
+138
-0
k_topk_dense_large.cu
theano/gpuarray/k_topk_dense_large.cu
+167
-0
sort.py
theano/gpuarray/sort.py
+0
-0
sort.py
theano/tensor/sort.py
+9
-8
test_sort.py
theano/tensor/tests/test_sort.py
+0
-3
没有找到文件。
theano/gpuarray/
topk_kernel.cu
→
theano/gpuarray/
k_topk_common.cuh
浏览文件 @
933cb859
...
...
@@ -9,26 +9,52 @@
// will all be adjacent
template <typename T>
struct RadixConfig {};
struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <>
struct RadixConfig<float> {
struct RadixConfig<
ga_
float> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) {
static inline __device__ RadixType convert(
ga_
float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
static inline __device__ float deconvert(RadixType v) {
static inline __device__
ga_
float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
template <>
struct RadixConfig<ga_double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(ga_double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ ga_double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
};
template <>
struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType;
...
...
@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
};
template <>
struct RadixConfig<
char
> {
struct RadixConfig<
ga_byte
> {
typedef ga_uint RadixType;
static inline __device__ RadixType convert(
char
v) {
static inline __device__ RadixType convert(
ga_byte
v) {
return 128u + v;
}
static inline __device__
char
deconvert(RadixType v) {
static inline __device__
ga_byte
deconvert(RadixType v) {
return v - 128;
}
};
...
...
@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2);
return 32768u
+
v;
return 32768u
^
v;
}
static inline __device__ ga_short deconvert(RadixType v) {
...
...
@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
return
2147483648u +
v;
return
(1u << 31) ^
v;
}
static inline __device__ int deconvert(RadixType v) {
return
v - 2147483648u
;
return
(1u << 31) ^ v
;
}
};
template <>
struct RadixConfig<long> {
struct RadixConfig<
ga_
long> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) {
static inline __device__ RadixType convert(
ga_
long v) {
assert(sizeof(long) == 8);
return
9223372036854775808ull +
v;
return
(1ull << 63) ^
v;
}
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
static inline __device__ ga_long deconvert(RadixType v) {
return (1ull << 63) ^ v;
}
};
#ifdef USE_HALF
// TODO: make this work
template <>
struct RadixConfig<half> {
typedef ga_uint RadixType;
...
...
@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
}
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
// read array element using raw(byte) offset
// (`offset` is in bytes, matching the kernels' stride arguments; __ldg loads
// through the read-only data cache, so the pointed-to data must not be
// written concurrently by this kernel)
template <typename T>
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
return __ldg(((T*)((char*)ptr + offset)));
}
theano/gpuarray/k_topk_dense.cu
0 → 100644
浏览文件 @
933cb859
// works when length on axis is within max allowed threads in block (1024)
//
// Selects the top-k elements of one slice of a strided array.  One thread
// block handles one slice (picked via GID_0 in $set_slice) and each thread
// owns exactly one input element, so the slice length must fit in a block.
// The $-prefixed tokens ($dims, $dstv, ...) are placeholders substituted by
// the host-side code generator before compilation; the comment after each
// shows the expanded form.
//
// Arguments (after substitution):
//   dims_*       -- sizes of the non-selected axes (used by $set_slice)
//   dstv / dsti  -- output value / index arrays; each is written only when
//                   the corresponding WRITE_VALUE / WRITE_INDEX macro is 1
//   *_strides_*  -- byte strides; element access goes through ptr_at
//   k            -- number of elements to keep; k < 0 selects the k
//                   smallest instead of the k largest
//   src          -- input array
//   size         -- length of the selected axis (one thread per element)
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
// per-warp digit histograms (room for 32 warps), later reused as scratch
// by binary_cumsum_exclusive
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
// bins[b] = k2 - (count of candidates whose current digit >= b);
// bins[RADIX_SIZE] is a sentinel, see below
ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
// is_topk:   this thread's element is (so far) within the top k
// is_topkth: this thread's element still ties for the top-k-th value
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_ushort idx = LID_0;
// k2: slots still to be resolved among the tied candidates (thread 0 only
// writes); exceed: how many tied top-k-th values must be dropped
ga_size LOCAL_MEM k2, exceed;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k: inverting all bits of the radix encoding reverses the
// ordering, so the "k largest" machinery then yields the k smallest
if (k<0) { x = ~x; k = -k; }
if (idx==0) {
k2 = k;
// sentinel: the bins[digit+1] lookup below may index RADIX_SIZE; a
// positive value there means "no verdict from the (nonexistent) higher bin"
bins[RADIX_SIZE] = 1;
}
// 1. filter is_topk and is_topkth using radix select
// Walk the encoding MSB-first, RADIX_BITS per round: histogram the current
// digit of the still-tied candidates, then narrow the tie set.
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// bins = k - cumsum(smem[:RADIX_SIZE])
// (cumsum runs from the highest digit downward; k2 shrinks to the number
// of slots left once larger-digit candidates are accounted for)
if (idx == 0) {
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
#pragma unroll
for (int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
// all candidates with this digit or higher fit within k: this
// element is definitely in the top k, but the top-k-th boundary
// lies in a lower bin, so it no longer ties for it
is_topkth = false;
} else if (bins[digit+1] <= 0) {
// k or more candidates have a strictly higher digit: this element
// cannot be in the top k
is_topk = false;
is_topkth = false;
}
}
}
// compute `exceed`: by how many the surviving candidates overshoot k
// (the top-k-th value need not be unique)
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (bins[bin] <= 0) {
exceed = -bins[bin];
break;
}
}
// NOTE(review): if no bin satisfies bins[bin] <= 0, `exceed` is read
// below uninitialized -- presumably unreachable for 0 < k <= size, but
// worth confirming
}
local_barrier();
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= ((!is_topkth) || out_idx>=exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
theano/gpuarray/k_topk_dense_large.cu
0 → 100644
浏览文件 @
933cb859
// works when length on axis is larger than max allowed threads in block (1024)
//
// Top-k selection for one slice of a strided array when the slice is too long
// for one-element-per-thread: each thread scans `inp_per_thread` consecutive
// elements instead.
//
// Step 1 finds the exact radix encoding of the top-k-th value with an
// MSB-first radix select (RADIX_BITS per round, counting only elements that
// still match the known bit prefix).  Step 2 writes every element strictly
// greater than that value, and step 3 writes just enough of the elements
// equal to it to reach k outputs in total.
//
// The $-prefixed tokens ($dims, $dstv, ...) are placeholders substituted by
// the host-side code generator before compilation; the comment after each
// shows the expanded form.
//
// Arguments (after substitution):
//   dims_*         -- sizes of the non-selected axes (used by $set_slice)
//   dstv / dsti    -- output value / index arrays; each is written only when
//                     the corresponding WRITE_VALUE / WRITE_INDEX macro is 1
//   *_strides_*    -- byte strides; access goes through ptr_at / ptr_read
//   k              -- number of elements to keep; k < 0 selects the k
//                     smallest (handled by bit-inverting the encoding)
//   src            -- input array
//   size           -- length of the selected axis
//   inp_per_thread -- consecutive elements scanned per thread (set by the
//                     host; presumably ceil(size / block size) -- confirm)
KERNEL void k_topk_dense_large(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size, ga_ushort inp_per_thread) {
// per-warp digit histograms (room for 32 warps), reused as cumsum scratch
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
// bits of the top-k-th value resolved so far, and the mask of resolved bits
LOCAL_MEM radix_t known_bits, known_bits_mask;
radix_t out_idx;
// running count of outputs already written by earlier passes over the data
ga_size LOCAL_MEM write_base;
INPUT_TYPE xval;
radix_t x;
ga_int i;
bool in_range, is_topk;
const ga_size idx = LID_0;
// slots still unresolved among candidates tied on the known prefix
ga_size LOCAL_MEM k2;
const ga_ushort warp_id = idx / GA_WARP_SIZE;
const ga_ushort lane_id = idx % GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// advance to this thread's chunk of the slice
src = ptr_add(src, idx*inp_per_thread*src_strides_0);
// XOR mask applied to every encoding: ~0 reverses the ordering so that
// negative k ("k smallest") reuses the "k largest" machinery unchanged
LOCAL_MEM radix_t inv_bits;
if (idx==0) {
known_bits = known_bits_mask = 0;
k2 = abs(k);
inv_bits = (k>=0) ? 0 : (~0);
write_base = 0;
}
if (k<0) { k = -k; }
local_barrier();
// 1. find bits of top-k-th value using radix select
#pragma unroll
for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
// reset this warp's histogram
if (lane_id == 0) {
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
smem[bin + warp_id*RADIX_SIZE] = 0;
}
}
local_barrier();
// accumulate digit counts over this thread's chunk, considering only
// elements that still match the resolved prefix of the top-k-th value
for (int j=0; j<inp_per_thread; ++j) {
in_range = (idx*inp_per_thread+j) < size;
xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (
(bin == digit) &&
((x&known_bits_mask) == known_bits) &&
in_range);
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
}
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE;
w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// update known bits: scan bins from the highest digit down; the first bin
// whose count reaches k2 contains the top-k-th value
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] >= k2) {
// cast before shifting: for 64-bit input types i can reach
// bitsof(INPUT_TYPE)-RADIX_BITS (> 31), and shifting a plain
// 32-bit int that far is undefined behavior
known_bits |= ((radix_t)bin << i);
known_bits_mask |= ((radix_t)(RADIX_SIZE-1) << i);
break;
} else
k2 -= smem[bin];
}
}
local_barrier();
}
// 2. write values strictly greater than the top-k-th (in radix order,
// i.e. ranked above it); binary_cumsum assigns compact output slots and
// its inclusive total (last thread's out_idx) advances write_base
for (i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x > known_bits) && in_range;
out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
}
local_barrier();
// NOTE(review): uses CUDA blockDim.x directly where the rest of the file
// uses the LDIM_0 abstraction -- confirm this kernel is CUDA-only
if (idx == blockDim.x - 1)
write_base += out_idx;
local_barrier();
}
// 3. write values equal to top-kth, stopping once k outputs exist in total
// (the top-k-th value need not be unique)
for (i=0; i<inp_per_thread; ++i) {
in_range = (idx*inp_per_thread+i) < size;
xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
is_topk = (x == known_bits) && in_range;
out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
// drop ties that would push the output count past k
is_topk = ((out_idx+write_base) <= abs(k)) && is_topk;
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
}
local_barrier();
if (idx == blockDim.x - 1)
write_base += out_idx;
local_barrier();
if(write_base >= abs(k))
break;
}
}
theano/gpuarray/sort.py
浏览文件 @
933cb859
差异被折叠。
点击展开。
theano/tensor/sort.py
浏览文件 @
933cb859
...
...
@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code
__props__
=
(
'axis'
,
'return_values'
,
'return_indices'
)
__props__
=
(
'axis'
,
'return_values'
,
'return_indices'
,
'idx_dtype'
)
def
__init__
(
self
,
axis
=-
1
,
return_indices
=
False
,
return_values
=
True
):
def
__init__
(
self
,
axis
=-
1
,
return_indices
=
False
,
return_values
=
True
,
idx_dtype
=
'int64'
):
assert
isinstance
(
axis
,
int
)
assert
return_indices
or
return_values
self
.
axis
=
axis
self
.
return_indices
=
return_indices
self
.
return_values
=
return_values
self
.
idx_dtype
=
idx_dtype
def
__str__
(
self
):
return
'
%(op)
s{axis=
%(axis)
d}'
%
dict
(
op
=
self
.
__class__
.
__name__
,
axis
=
self
.
axis
)
def
make_node
(
self
,
inp
,
k
,
idx_dtype
=
'int64'
):
def
make_node
(
self
,
inp
,
k
):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
inp
=
theano
.
tensor
.
as_tensor_variable
(
inp
)
...
...
@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs
.
append
(
inp
.
type
())
if
self
.
return_indices
:
outs
.
append
(
theano
.
tensor
.
TensorType
(
dtype
=
idx_dtype
,
broadcastable
=
bcast
)())
theano
.
tensor
.
TensorType
(
dtype
=
self
.
idx_dtype
,
broadcastable
=
bcast
)())
return
theano
.
Apply
(
self
,
[
inp
,
k
],
outs
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
...
...
@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if
axis
is
None
:
x
=
theano
.
tensor
.
flatten
(
x
)
axis
=
-
1
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
return_values
=
False
)(
x
,
k
,
idx_dtype
=
idx_dtype
)
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
return_values
=
False
,
idx_dtype
=
idx_dtype
)(
x
,
k
)
def
topk_and_argtopk
(
x
,
k
,
axis
=-
1
,
idx_dtype
=
'int64'
):
'''
"""
Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details.
'''
"""
if
axis
is
None
:
x
=
theano
.
tensor
.
flatten
(
x
)
axis
=
-
1
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
)(
x
,
k
,
idx_dtype
=
idx_dtype
)
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
idx_dtype
=
idx_dtype
)(
x
,
k
)
theano/tensor/tests/test_sort.py
浏览文件 @
933cb859
...
...
@@ -21,14 +21,12 @@ _int_dtypes = (
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
)
def
gen_unique_vector
(
size
,
dtype
):
# generate a randomized vector with unique elements
retval
=
np
.
arange
(
size
*
3
)
+
np
.
random
.
uniform
(
-
1.
,
1.
)
return
(
retval
[
np
.
random
.
permutation
(
size
)]
-
size
*
1.5
)
.
astype
(
dtype
)
'''
class
Test_sort
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
@@ -236,7 +234,6 @@ def test_argsort_grad():
data
=
np
.
random
.
rand
(
2
,
3
,
3
)
.
astype
(
theano
.
config
.
floatX
)
utt
.
verify_grad
(
lambda
x
:
argsort
(
x
,
axis
=
2
),
[
data
])
'''
class
Test_TopK
(
unittest
.
TestCase
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论