Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
933cb859
提交
933cb859
authored
5月 18, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement new kernel to handle arbitrary shape
上级
bf83d342
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
438 行增加
和
217 行删除
+438
-217
k_topk_common.cuh
theano/gpuarray/k_topk_common.cuh
+47
-162
k_topk_dense.cu
theano/gpuarray/k_topk_dense.cu
+138
-0
k_topk_dense_large.cu
theano/gpuarray/k_topk_dense_large.cu
+167
-0
sort.py
theano/gpuarray/sort.py
+77
-44
sort.py
theano/tensor/sort.py
+9
-8
test_sort.py
theano/tensor/tests/test_sort.py
+0
-3
没有找到文件。
theano/gpuarray/
topk_kernel.cu
→
theano/gpuarray/
k_topk_common.cuh
浏览文件 @
933cb859
...
@@ -9,26 +9,52 @@
...
@@ -9,26 +9,52 @@
// will all be adjacent
// will all be adjacent
template <typename T>
template <typename T>
struct RadixConfig {};
struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <>
template <>
struct RadixConfig<float> {
struct RadixConfig<
ga_
float> {
typedef ga_uint RadixType;
typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) {
static inline __device__ RadixType convert(
ga_
float v) {
RadixType x = __float_as_int(v);
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
return (x ^ mask);
}
}
static inline __device__ float deconvert(RadixType v) {
static inline __device__
ga_
float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
return __int_as_float(v ^ mask);
}
}
};
};
// Radix key mapping for ga_double: reinterpret the bits as an unsigned
// 64-bit integer and flip them so that unsigned integer comparison of the
// keys follows the same order as the original floating point values.
template <>
struct RadixConfig<ga_double> {
    typedef unsigned long long int RadixType;

    // value -> key: a value with the sign bit set has every bit flipped,
    // any other value has only the sign bit flipped.
    static inline __device__ RadixType convert(ga_double v) {
        const RadixType sign_bit = 0x8000000000000000;
        const RadixType bits = __double_as_longlong(v);
        const RadixType flip = (bits & sign_bit) ? ~((RadixType)0) : sign_bit;
        return bits ^ flip;
    }

    // key -> value: exact inverse of convert(). A key with the top bit set
    // came from a value whose sign bit was clear, so only that bit is
    // flipped back; otherwise every bit is flipped back.
    static inline __device__ ga_double deconvert(RadixType v) {
        const RadixType sign_bit = 0x8000000000000000;
        const RadixType flip = (v & sign_bit) ? sign_bit : ~((RadixType)0);
        return __longlong_as_double(v ^ flip);
    }
};
template <>
template <>
struct RadixConfig<ga_ubyte> {
struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType;
typedef ga_uint RadixType;
...
@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
...
@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
};
};
template <>
template <>
struct RadixConfig<
char
> {
struct RadixConfig<
ga_byte
> {
typedef ga_uint RadixType;
typedef ga_uint RadixType;
static inline __device__ RadixType convert(
char
v) {
static inline __device__ RadixType convert(
ga_byte
v) {
return 128u + v;
return 128u + v;
}
}
static inline __device__
char
deconvert(RadixType v) {
static inline __device__
ga_byte
deconvert(RadixType v) {
return v - 128;
return v - 128;
}
}
};
};
...
@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
...
@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) {
static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2);
assert(sizeof(ga_short) == 2);
return 32768u
+
v;
return 32768u
^
v;
}
}
static inline __device__ ga_short deconvert(RadixType v) {
static inline __device__ ga_short deconvert(RadixType v) {
...
@@ -75,45 +101,30 @@ struct RadixConfig<int> {
...
@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
assert(sizeof(int) == 4);
return
2147483648u +
v;
return
(1u << 31) ^
v;
}
}
static inline __device__ int deconvert(RadixType v) {
static inline __device__ int deconvert(RadixType v) {
return
v - 2147483648u
;
return
(1u << 31) ^ v
;
}
}
};
};
template <>
template <>
struct RadixConfig<long> {
struct RadixConfig<
ga_
long> {
typedef unsigned long long int RadixType;
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) {
static inline __device__ RadixType convert(
ga_
long v) {
assert(sizeof(long) == 8);
assert(sizeof(long) == 8);
return 9223372036854775808ull + v;
return (1ull << 63) ^ v;
}
static inline __device__ long deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
}
static inline __device__ double deconvert(RadixType v) {
static inline __device__ ga_long deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return (1ull << 63) ^ v;
return __longlong_as_double(v ^ mask);
}
}
};
};
#ifdef USE_HALF
#ifdef USE_HALF
// TODO: make this work
template <>
template <>
struct RadixConfig<half> {
struct RadixConfig<half> {
typedef ga_uint RadixType;
typedef ga_uint RadixType;
...
@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
...
@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset));
return *((T*)((char*)ptr + offset));
}
}
KERNEL void k_topk_dense(
// read array element using raw(byte) offset
$dims
template <typename T>
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
$dstv
return __ldg(((T*)((char*)ptr + offset)));
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
}
theano/gpuarray/k_topk_dense.cu
0 → 100644
浏览文件 @
933cb859
// k_topk_dense: top-k selection along one axis using a radix select.
// Works when the length on the axis is within the max allowed threads in a
// block (1024): every thread handles at most one element (idx = LID_0).
// This text is a template; the dollar-prefixed placeholders below are
// substituted by the host-side code before the kernel is compiled.
// A negative k bitwise-inverts the radix key, which reverses the ordering
// used for the selection; |k| elements are written in either case.
KERNEL void k_topk_dense(
$dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
$dstv
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
// one row of RADIX_SIZE digit counters per warp (up to 32 warps)
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
// bins[0..RADIX_SIZE-1] hold "remaining k minus cumulative count";
// bins[RADIX_SIZE] is an extra slot set to 1 below so that bins[digit+1]
// can be read for the highest digit without a bounds check
ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
// NOTE(review): as used below, is_topk marks "this element is still selected
// for output" and is_topkth marks "still undecided / tied on all digits
// examined so far" -- confirm against the algorithm description
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_ushort idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_ubyte warp_id = idx / GA_WARP_SIZE;
const ga_ubyte lane_id = idx % GA_WARP_SIZE;
// threads past the end of the slice still take part in every barrier and
// ballot, but never count as candidates
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
// thread 0 initialises the shared state; the barriers inside the loop below
// are reached before any other thread reads k2 or bins
if (idx==0) {
k2 = k;
bins[RADIX_SIZE] = 1;
}
// 1. filter is_topk and is_topkth using radix select
// digits are processed from the most significant one downwards
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// bins = k - cumsum(smem[:RADIX_SIZE])
// (accumulated from the highest digit downwards; k2 keeps the last
// remainder that is still positive)
if (idx == 0) {
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
#pragma unroll
for (int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
// remainder still positive after this digit's bucket: the element is
// decided as selected and no longer needs to be counted
is_topkth = false;
} else if (bins[digit+1] <= 0) {
// the remainder was already used up by a higher digit: drop the element
is_topk = false;
is_topkth = false;
}
// otherwise this digit's bucket is where the remainder runs out; the
// element stays undecided and is refined by the next digit
}
}
// exceed = minus the first non-positive entry of bins, scanning from the
// highest digit down (computed from the bins of the last digit pass)
// NOTE(review): exceed is left unwritten when every bins entry is positive
// -- presumably the host guarantees |k| <= size; verify against caller
if (idx==0) {
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (bins[bin] <= 0) {
exceed = -bins[bin];
break;
}
}
}
local_barrier();
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
// tied elements whose exclusive prefix count is below exceed are dropped
is_topk &= ((!is_topkth) || out_idx>=exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
// the stored index is the element's position along the selected axis
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
}
theano/gpuarray/k_topk_dense_large.cu
0 → 100644
浏览文件 @
933cb859
// k_topk_dense_large: top-k selection along one axis using a radix select.
// Works when the length on the axis is larger than the max allowed threads
// in a block (1024): every thread handles inp_per_thread consecutive
// elements of the slice.
// This text is a template; the dollar-prefixed placeholders below are
// substituted by the host-side code before the kernel is compiled.
//
// k               number of elements to select; a negative value inverts the
//                 radix keys (reversing the ordering) and selects |k|
// src             input data, addressed with byte strides
// size            length of the slice along the selected axis
// inp_per_thread  number of consecutive elements owned by each thread
//
// Phase 1 determines, digit by digit from the most significant one, the
// radix key of the k-th element (known_bits). Phase 2 writes every element
// whose key is strictly greater than known_bits, phase 3 fills the
// remaining output slots with elements whose key equals known_bits.
KERNEL void k_topk_dense_large(
        $dims
        // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_strides
        // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_strides
        // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
        ga_ssize k,
        INPUT_TYPE* src,
        $src_strides
        // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
        ga_size size, ga_ushort inp_per_thread) {
    // one row of RADIX_SIZE digit counters per warp (up to 32 warps)
    LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
    // digits of the k-th key found so far, and the mask of bits already fixed
    LOCAL_MEM radix_t known_bits, known_bits_mask;
    radix_t out_idx;
    // number of output slots already consumed by previous write rounds
    ga_size LOCAL_MEM write_base;
    INPUT_TYPE xval;
    radix_t x;
    ga_int i;
    bool in_range, is_topk;

    const ga_size idx = LID_0;
    ga_size LOCAL_MEM k2;
    const ga_ushort warp_id = idx / GA_WARP_SIZE;
    const ga_ushort lane_id = idx % GA_WARP_SIZE;

    // 0. get the slice for thread block to work on
    // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
    ga_size gid = GID_0, gidx;
    $set_slice
    //for(int i=1; i<NDIM; i++) {
        // gidx = gid % dims_$${i};
        // gid /= dims_$${i};
        // dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
        // dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
        // src = ptr_add(src, gidx*src_strides_$${i});
    //}

    // advance src to the first element owned by this thread
    src = ptr_add(src, idx*inp_per_thread*src_strides_0);

    // XOR-ed into every key: all ones for negative k (reverses the order)
    LOCAL_MEM radix_t inv_bits;
    if (idx==0) {
        known_bits = known_bits_mask = 0;
        k2 = abs(k);
        inv_bits = (k>=0) ? 0 : (~0);
        write_base = 0;
    }
    if (k<0) { k = -k; }
    local_barrier();

    // 1. find bits of top-k-th value using radix select
    #pragma unroll
    for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        // reset this warp's row of counters
        if (lane_id == 0) {
            #pragma unroll
            for (int bin=0; bin<RADIX_SIZE; ++bin) {
                smem[bin + warp_id*RADIX_SIZE] = 0;
            }
        }
        local_barrier();

        for (int j=0; j<inp_per_thread; ++j) {
            in_range = (idx*inp_per_thread+j) < size;
            xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
            x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
            ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));

            // count within warp; only elements that agree with the digits
            // fixed so far are counted
            #pragma unroll
            for (int bin=0; bin<RADIX_SIZE; ++bin) {
                bool incr_bin = (
                    (bin == digit) &&
                    ((x&known_bits_mask) == known_bits) &&
                    in_range);
                ga_uint incr_bin_warp = __ballot(incr_bin);
                if (lane_id==0)
                    smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
            }
        }
        local_barrier();

        // sum counts across all warps
        // TODO: test in-block parallel sum?
        if (idx < RADIX_SIZE) {
            for(int w=RADIX_SIZE;
                    w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
                    w+=RADIX_SIZE)
                smem[idx] += smem[idx + w];
        }
        local_barrier();

        // update known bits: walk the buckets from the highest digit down
        // until the bucket containing the k2-th remaining element is found
        if (idx==0) {
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
                if (smem[bin] >= k2) {
                    // Widen to radix_t before shifting: i starts at
                    // bitsof(INPUT_TYPE)-RADIX_BITS, so with 64-bit inputs
                    // it reaches 32 or more, and shifting a plain int by
                    // that amount is undefined and loses the upper digits.
                    // NOTE(review): assumes radix_t is at least as wide as
                    // INPUT_TYPE -- confirm where radix_t is defined.
                    known_bits |= (((radix_t)bin) << i);
                    known_bits_mask |= (((radix_t)(RADIX_SIZE-1)) << i);
                    break;
                } else
                    k2 -= smem[bin];
            }
        }
        local_barrier();
    }

    // 2. write values whose key is strictly greater than the top-kth key
    for (i=0; i<inp_per_thread; ++i) {
        in_range = (idx*inp_per_thread+i) < size;
        xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
        x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
        is_topk = (x > known_bits) && in_range;
        // out_idx is used as a 1-based position within this round
        // (hence the -1 when forming the output offset)
        out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
        if (is_topk) {
#if WRITE_VALUE == 1
            ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
            ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
        }
        local_barrier();

        // the last thread holds the total for this round
        if (idx == LDIM_0 - 1)
            write_base += out_idx;
        local_barrier();
    }

    // 3. write values equal to top-kth, until k outputs have been written
    for (i=0; i<inp_per_thread; ++i) {
        in_range = (idx*inp_per_thread+i) < size;
        xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
        x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
        is_topk = (x == known_bits) && in_range;
        out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
        // ties beyond the k-th output slot are dropped
        is_topk = ((out_idx+write_base) <= abs(k)) && is_topk;
        if (is_topk) {
#if WRITE_VALUE == 1
            ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
            ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
        }
        local_barrier();

        if (idx == LDIM_0 - 1)
            write_base += out_idx;
        local_barrier();

        // write_base is shared and read after the barrier, so every thread
        // takes the same branch here
        if(write_base >= abs(k))
            break;
    }
}
theano/gpuarray/sort.py
浏览文件 @
933cb859
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
import
os
import
os
from
string
import
Template
from
string
import
Template
import
pdb
import
numpy
as
np
import
theano
import
theano
from
theano
import
Apply
from
theano
import
Apply
from
theano.tensor
import
as_tensor_variable
from
theano.tensor
import
as_tensor_variable
...
@@ -20,7 +22,6 @@ except ImportError as e:
...
@@ -20,7 +22,6 @@ except ImportError as e:
pass
pass
# TODO add support when slice size is larger than max allowed block size (1024)
# TODO add runtime opt, if k==1, use max/min reduce
# TODO add runtime opt, if k==1, use max/min reduce
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed
# one result is needed
...
@@ -33,12 +34,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -33,12 +34,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
'''
'''
__props__
=
TopKOp
.
__props__
__props__
=
TopKOp
.
__props__
def
__init__
(
self
,
axis
=-
1
,
return_
indices
=
False
,
return_values
=
True
):
def
__init__
(
self
,
axis
=-
1
,
return_
values
=
True
,
return_indices
=
False
,
idx_dtype
=
'int64'
):
GpuKernelBase
.
__init__
(
self
)
GpuKernelBase
.
__init__
(
self
)
TopKOp
.
__init__
(
TopKOp
.
__init__
(
self
,
axis
=
axis
,
self
,
axis
=
axis
,
return_values
=
return_values
,
return_values
=
return_values
,
return_indices
=
return_indices
)
return_indices
=
return_indices
,
idx_dtype
=
idx_dtype
)
def
c_headers
(
self
):
def
c_headers
(
self
):
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
return
[
'gpuarray_api.h'
,
'gpuarray_helper.h'
,
'numpy_compat.h'
]
...
@@ -54,19 +56,23 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -54,19 +56,23 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
def
gpu_kernels
(
self
,
node
,
nodename
):
def
gpu_kernels
(
self
,
node
,
nodename
):
# load kernel source
# load kernel source
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
device_type
=
node
.
inputs
[
0
]
.
type
.
context
.
kind
knames
=
[
'k_topk_dense'
,
'k_topk_dense_large'
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
kernel_ext
=
{
b
'cuda'
:
'.cu'
,
b
'opencl'
:
'.cl'
}[
device_type
]
try
:
common_ext
=
{
b
'cuda'
:
'.cuh'
,
b
'opencl'
:
'.h'
}[
device_type
]
kernel_filename
=
'topk_kernel
%
s'
%
kernel_ext
kernel_src
=
{}
for
kname
in
knames
:
with
open
(
os
.
path
.
join
(
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
k
ernel_filename
os
.
path
.
dirname
(
__file__
),
k
name
+
kernel_ext
),
'r'
)
as
f
:
),
'r'
)
as
f
:
kernel_src
=
f
.
read
()
kernel_src
[
kname
]
=
f
.
read
()
except
FileNotFoundError
:
raise
RuntimeError
(
with
open
(
os
.
path
.
join
(
'Cannot find GPU kernel '
os
.
path
.
dirname
(
__file__
),
'k_topk_common'
+
common_ext
'implementation for device "
%
s"'
%
device_type
)
),
'r'
)
as
f
:
common_src
=
f
.
read
()
# prepare "$" macros
# prepare "$" macros
if
device_type
==
b
'cuda'
:
ndim
=
node
.
inputs
[
0
]
.
ndim
ndim
=
node
.
inputs
[
0
]
.
ndim
dstv_strides_code
=
''
.
join
(
'ga_ssize dstv_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dstv_strides_code
=
''
.
join
(
'ga_ssize dstv_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dsti_strides_code
=
''
.
join
(
'ga_ssize dsti_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dsti_strides_code
=
''
.
join
(
'ga_ssize dsti_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
...
@@ -84,7 +90,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -84,7 +90,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
flags
=
Kernel
.
get_flags
(
node
.
inputs
[
0
]
.
dtype
)
flags
=
Kernel
.
get_flags
(
node
.
inputs
[
0
]
.
dtype
)
subs
=
dict
(
subs
=
dict
(
inp_t
=
ga
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
),
inp_t
=
ga
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
),
out_t
=
ga
.
dtype_to_ctype
(
node
.
outputs
[
0
]
.
dtype
),
out_t
=
ga
.
dtype_to_ctype
(
self
.
idx_
dtype
),
dims
=
''
.
join
(
'ga_size dims_
%
d, '
%
i
for
i
in
range
(
1
,
ndim
)),
dims
=
''
.
join
(
'ga_size dims_
%
d, '
%
i
for
i
in
range
(
1
,
ndim
)),
dstv
=
'INPUT_TYPE *dstv,'
if
self
.
return_values
else
''
,
dstv
=
'INPUT_TYPE *dstv,'
if
self
.
return_values
else
''
,
dsti
=
'INDEX_TYPE *dsti,'
if
self
.
return_indices
else
''
,
dsti
=
'INDEX_TYPE *dsti,'
if
self
.
return_indices
else
''
,
...
@@ -95,11 +101,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -95,11 +101,11 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
write_value
=
int
(
self
.
return_values
),
write_value
=
int
(
self
.
return_values
),
write_index
=
int
(
self
.
return_indices
),
write_index
=
int
(
self
.
return_indices
),
ndim
=
str
(
ndim
))
ndim
=
str
(
ndim
))
elif
device_type
==
b
'opencl'
:
raise
NotImplementedError
()
# substitute "$" macros in kernel code
# compile kernels
kernel_src
=
Template
(
kernel_src
)
.
substitute
(
**
subs
)
kernels
=
[]
# compile kernel
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
param_types
=
[
ga
.
SIZE
]
*
(
ndim
-
1
)
# dims
for
_
in
range
(
int
(
self
.
return_values
)
+
int
(
self
.
return_indices
)):
for
_
in
range
(
int
(
self
.
return_values
)
+
int
(
self
.
return_indices
)):
param_types
.
append
(
ga
.
GpuArray
)
# dst*
param_types
.
append
(
ga
.
GpuArray
)
# dst*
...
@@ -108,31 +114,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -108,31 +114,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
append
(
ga
.
GpuArray
)
# src
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
extend
([
ga
.
SSIZE
]
*
ndim
)
# src_strides
param_types
.
append
(
ga
.
SIZE
)
# size
param_types
.
append
(
ga
.
SIZE
)
# size
self
.
nargs
=
len
(
param_types
)
kernels
.
append
(
Kernel
(
return
[
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense'
])
.
substitute
(
**
subs
),
code
=
kernel_src
,
name
=
'k_topk_dense'
,
name
=
'k_topk_dense'
,
params
=
param_types
,
params
=
param_types
,
flags
=
flags
,
flags
=
flags
,
objvar
=
'k_topk_dense_'
+
nodename
objvar
=
'k_topk_dense_'
+
nodename
)]
))
param_types
.
append
(
np
.
uint16
)
# inp_per_thread
kernels
.
append
(
Kernel
(
code
=
Template
(
common_src
+
kernel_src
[
'k_topk_dense_large'
])
.
substitute
(
**
subs
),
name
=
'k_topk_dense_large'
,
params
=
param_types
,
flags
=
flags
,
objvar
=
'k_topk_dense_large_'
+
nodename
))
return
kernels
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inps
,
outs
,
sub
):
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
'We only have CUDA implementation so far.'
)
raise
NotImplementedError
(
'
%
s: We only have CUDA '
'implementation so far.'
%
self
.
__class__
.
__name__
)
x
,
k
=
inps
x
,
k
=
inps
inp_dtc
=
pygpu
.
dtypes
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
)
.
upper
(
)
inp_dtc
=
ga
.
dtype_to_typecode
(
node
.
inputs
[
0
]
.
dtype
)
if
not
self
.
return_indices
:
if
not
self
.
return_indices
:
yv
,
=
outs
yv
,
=
outs
out_dtype_s
=
''
out_dtc
=
''
else
:
else
:
if
self
.
return_values
:
if
self
.
return_values
:
yv
,
yi
=
outs
yv
,
yi
=
outs
else
:
else
:
yi
,
=
outs
yi
,
=
outs
out_dtype_s
=
node
.
outputs
[
0
]
.
dtype
out_dtype_s
=
self
.
idx_
dtype
out_dtc
=
pygpu
.
dtypes
.
dtype_to_ctype
(
out_dtype_s
)
.
upper
(
)
out_dtc
=
ga
.
dtype_to_typecode
(
out_dtype_s
)
fail
=
sub
[
'fail'
]
fail
=
sub
[
'fail'
]
ctx
=
sub
[
'params'
]
ctx
=
sub
[
'params'
]
k_dtype
=
node
.
inputs
[
1
]
.
type
.
dtype_specs
()[
1
]
k_dtype
=
node
.
inputs
[
1
]
.
type
.
dtype_specs
()[
1
]
...
@@ -140,7 +154,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -140,7 +154,6 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
WARP_SIZE
=
32
WARP_SIZE
=
32
ndim
=
node
.
inputs
[
0
]
.
ndim
ndim
=
node
.
inputs
[
0
]
.
ndim
nargs
=
self
.
nargs
reordered_axes
=
list
(
range
(
ndim
))
reordered_axes
=
list
(
range
(
ndim
))
axis
=
self
.
axis
%
ndim
axis
=
self
.
axis
%
ndim
del
(
reordered_axes
[
axis
])
del
(
reordered_axes
[
axis
])
...
@@ -175,16 +188,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -175,16 +188,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
sstrides
=
', '
.
join
(
'(void*)(sstrides+
%
d)'
%
i
for
i
in
reordered_axes
)
sstrides
=
', '
.
join
(
'(void*)(sstrides+
%
d)'
%
i
for
i
in
reordered_axes
)
code
=
'''
code
=
'''
{
{
const ssize_t k_ = ((
%(k_dtype)
s*)(PyArray_DATA(
%(k)
s)))[0];
const size_t *dims = PyGpuArray_DIMS(
%(x)
s);
const size_t *dims = PyGpuArray_DIMS(
%(x)
s);
size_t odims[
%(ndim)
d];
size_t odims[
%(ndim)
d];
for (int i=0; i<
%(ndim)
d; i++)
for (int i=0; i<
%(ndim)
d; i++)
odims[i] = dims[i];
odims[i] = dims[i];
odims[
%(axis)
d] = *((
%(k_dtype)
s*)(PyArray_DATA(
%(k)
s)));
if (odims[0] >
%(MAX_TPB)
d) {
odims[
%(axis)
d] = k_>=0 ? k_ : -k_;
if (0 == odims[
%(axis)
d]) {
PyErr_SetString(
PyErr_SetString(
PyExc_ValueError,
PyExc_ValueError,
"topk: slice size larger than
%(MAX_TPB)
d is not supported");
"topk: k must not be zero");
%(fail)
s; }
%(fail)
s;
}
%(prep_output)
s
%(prep_output)
s
// TODO better scheduling?
// TODO better scheduling?
...
@@ -192,32 +210,45 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -192,32 +210,45 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
size_t *grd = blk+3;
size_t *grd = blk+3;
blk[0] = blk[1] = blk[2] = 1;
blk[0] = blk[1] = blk[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
grd[0] = grd[1] = grd[2] = 1;
// round up to multiples of warp size
for(int i=0; i<
%(ndim)
d; ++i) {
for(int i=0; i<
%(ndim)
d; ++i) {
if (i!=
%(axis)
d)
if (i!=
%(axis)
d)
grd[0] *= dims[i];
grd[0] *= dims[i];
else
else
blk[0] = dims[i];
blk[0] = dims[i];
}
}
// round up to multiples of warp size
blk[0] = ((blk[0] +
%(WARP_SIZE)
d - 1) /
%(WARP_SIZE)
d) *
%(WARP_SIZE)
d;
blk[0] = ((blk[0] +
%(WARP_SIZE)
d - 1) /
%(WARP_SIZE)
d) *
%(WARP_SIZE)
d;
%(def_dvstrides)
s;
%(def_dvstrides)
s;
%(def_distrides)
s;
%(def_distrides)
s;
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
const ssize_t *sstrides = PyGpuArray_STRIDES(
%(x)
s);
// inputs per thread
unsigned short ipt = (dims[
%(axis)
d] + (
%(MAX_TPB)
d/2)-1) / (
%(MAX_TPB)
d/2);
void* args[] = {
void* args[] = {
%(dims)
s
%(dims)
s
%(params_dv)
s
%(params_dv)
s
%(params_di)
s
%(params_di)
s
(void*)(
odims+
%(axis)
d
),
(void*)(
&k_
),
(void*)(
%(x)
s->ga.data),
(void*)(
%(x)
s->ga.data),
%(sstrides)
s,
%(sstrides)
s,
(void*)(dims+
%(axis)
d)
(void*)(dims+
%(axis)
d),
(void*)(&ipt)
};
};
int err = GpuKernel_call(
int err;
if (blk[0] >
%(MAX_TPB)
d) {
// CUDA_OUT_OF_RESOURCE if a max sized block is used
blk[0] =
%(MAX_TPB)
d / 2;
err = GpuKernel_call(
&k_topk_dense_large_
%(nodename)
s, 3,
grd, blk, 0,
args);
} else {
err = GpuKernel_call(
&k_topk_dense_
%(nodename)
s, 3,
&k_topk_dense_
%(nodename)
s, 3,
grd, blk, 0,
grd, blk, 0,
args);
args);
}
if (err != GA_NO_ERROR) {
if (err != GA_NO_ERROR) {
PyErr_SetString(
PyErr_SetString(
PyExc_RuntimeError,
PyExc_RuntimeError,
...
@@ -228,37 +259,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -228,37 +259,39 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
'''
'''
return
code
%
locals
()
return
code
%
locals
()
def
make_node
(
self
,
inp
,
k
,
idx_dtype
=
'int64'
):
def
make_node
(
self
,
inp
,
k
):
ctx_name
=
infer_context_name
(
inp
)
ctx_name
=
infer_context_name
(
inp
)
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
inp
=
as_gpuarray_variable
(
inp
,
ctx_name
)
k
=
as_tensor_variable
(
k
)
k
=
as_tensor_variable
(
k
)
bcast
=
inp
.
type
.
broadcastable
bcast
=
inp
.
type
.
broadcastable
outs
=
[]
outs
=
[]
if
self
.
return_values
:
outs
.
append
(
inp
.
type
())
if
self
.
return_indices
:
if
self
.
return_indices
:
outs
.
append
(
GpuArrayType
(
outs
.
append
(
GpuArrayType
(
dtype
=
idx_dtype
,
dtype
=
self
.
idx_dtype
,
broadcastable
=
bcast
,
broadcastable
=
bcast
,
context_name
=
ctx_name
)())
context_name
=
ctx_name
)())
if
self
.
return_values
:
outs
.
append
(
inp
.
type
())
return
Apply
(
self
,
[
inp
,
k
],
outs
)
return
Apply
(
self
,
[
inp
,
k
],
outs
)
def
get_params
(
self
,
node
):
def
get_params
(
self
,
node
):
return
node
.
inputs
[
0
]
.
type
.
context
return
node
.
inputs
[
0
]
.
type
.
context
# def get_op_params(self):
# return [('AXIS', self.axis)]
@register_opt
(
'fast_compile'
)
@register_opt
(
'fast_compile'
)
@op_lifter
([
TopKOp
])
@op_lifter
([
TopKOp
])
@register_opt2
([
TopKOp
],
'fast_compile'
)
@register_opt2
([
TopKOp
],
'fast_compile'
)
def
local_gpua_topkop
(
op
,
ctx_name
,
inputs
,
outputs
):
def
local_gpua_topkop
(
op
,
ctx_name
,
inputs
,
outputs
):
if
isinstance
(
op
,
GpuTopKOp
):
return
False
axis
=
op
.
axis
axis
=
op
.
axis
rv
=
op
.
return_values
rv
=
op
.
return_values
ri
=
op
.
return_indices
ri
=
op
.
return_indices
x
,
k
=
inputs
x
,
k
=
inputs
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
y
=
outputs
[
-
1
]
rets
=
GpuTopKOp
(
return
GpuTopKOp
(
axis
=
axis
,
return_values
=
rv
,
return_indices
=
ri
,
idx_dtype
=
op
.
idx_dtype
)(
x
,
k
)
axis
=
axis
,
return_values
=
rv
,
return_indices
=
ri
)(
x
,
k
,
idx_dtype
=
y
.
dtype
)
return
rets
theano/tensor/sort.py
浏览文件 @
933cb859
...
@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
...
@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code
# TODO c_code
__props__
=
(
'axis'
,
'return_values'
,
'return_indices'
)
__props__
=
(
'axis'
,
'return_values'
,
'return_indices'
,
'idx_dtype'
)
def
__init__
(
self
,
axis
=-
1
,
return_indices
=
False
,
return_values
=
True
):
def
__init__
(
self
,
axis
=-
1
,
return_indices
=
False
,
return_values
=
True
,
idx_dtype
=
'int64'
):
assert
isinstance
(
axis
,
int
)
assert
isinstance
(
axis
,
int
)
assert
return_indices
or
return_values
assert
return_indices
or
return_values
self
.
axis
=
axis
self
.
axis
=
axis
self
.
return_indices
=
return_indices
self
.
return_indices
=
return_indices
self
.
return_values
=
return_values
self
.
return_values
=
return_values
self
.
idx_dtype
=
idx_dtype
def
__str__
(
self
):
def
__str__
(
self
):
return
'
%(op)
s{axis=
%(axis)
d}'
%
dict
(
return
'
%(op)
s{axis=
%(axis)
d}'
%
dict
(
op
=
self
.
__class__
.
__name__
,
axis
=
self
.
axis
)
op
=
self
.
__class__
.
__name__
,
axis
=
self
.
axis
)
def
make_node
(
self
,
inp
,
k
,
idx_dtype
=
'int64'
):
def
make_node
(
self
,
inp
,
k
):
# numpy always uses float64 as output dtype for arg*() routines
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
# however, we add this option as memory is more precious on gpu
inp
=
theano
.
tensor
.
as_tensor_variable
(
inp
)
inp
=
theano
.
tensor
.
as_tensor_variable
(
inp
)
...
@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
...
@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs
.
append
(
inp
.
type
())
outs
.
append
(
inp
.
type
())
if
self
.
return_indices
:
if
self
.
return_indices
:
outs
.
append
(
outs
.
append
(
theano
.
tensor
.
TensorType
(
dtype
=
idx_dtype
,
broadcastable
=
bcast
)())
theano
.
tensor
.
TensorType
(
dtype
=
self
.
idx_dtype
,
broadcastable
=
bcast
)())
return
theano
.
Apply
(
self
,
[
inp
,
k
],
outs
)
return
theano
.
Apply
(
self
,
[
inp
,
k
],
outs
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
...
@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
...
@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if
axis
is
None
:
if
axis
is
None
:
x
=
theano
.
tensor
.
flatten
(
x
)
x
=
theano
.
tensor
.
flatten
(
x
)
axis
=
-
1
axis
=
-
1
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
return_values
=
False
)(
x
,
k
,
idx_dtype
=
idx_dtype
)
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
return_values
=
False
,
idx_dtype
=
idx_dtype
)(
x
,
k
)
def
topk_and_argtopk
(
x
,
k
,
axis
=-
1
,
idx_dtype
=
'int64'
):
def
topk_and_argtopk
(
x
,
k
,
axis
=-
1
,
idx_dtype
=
'int64'
):
'''
"""
Returns the results of both topk() and argtopk() in one Op.
Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details.
See the respective documentation for details.
'''
"""
if
axis
is
None
:
if
axis
is
None
:
x
=
theano
.
tensor
.
flatten
(
x
)
x
=
theano
.
tensor
.
flatten
(
x
)
axis
=
-
1
axis
=
-
1
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
)(
x
,
k
,
idx_dtype
=
idx_dtype
)
return
TopKOp
(
axis
=
axis
,
return_indices
=
True
,
idx_dtype
=
idx_dtype
)(
x
,
k
)
theano/tensor/tests/test_sort.py
浏览文件 @
933cb859
...
@@ -21,14 +21,12 @@ _int_dtypes = (
...
@@ -21,14 +21,12 @@ _int_dtypes = (
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
)
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
)
def
gen_unique_vector
(
size
,
dtype
):
def
gen_unique_vector
(
size
,
dtype
):
# generate a randomized vector with unique elements
# generate a randomized vector with unique elements
retval
=
np
.
arange
(
size
*
3
)
+
np
.
random
.
uniform
(
-
1.
,
1.
)
retval
=
np
.
arange
(
size
*
3
)
+
np
.
random
.
uniform
(
-
1.
,
1.
)
return
(
retval
[
np
.
random
.
permutation
(
size
)]
-
size
*
1.5
)
.
astype
(
dtype
)
return
(
retval
[
np
.
random
.
permutation
(
size
)]
-
size
*
1.5
)
.
astype
(
dtype
)
'''
class
Test_sort
(
unittest
.
TestCase
):
class
Test_sort
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -236,7 +234,6 @@ def test_argsort_grad():
...
@@ -236,7 +234,6 @@ def test_argsort_grad():
data
=
np
.
random
.
rand
(
2
,
3
,
3
)
.
astype
(
theano
.
config
.
floatX
)
data
=
np
.
random
.
rand
(
2
,
3
,
3
)
.
astype
(
theano
.
config
.
floatX
)
utt
.
verify_grad
(
lambda
x
:
argsort
(
x
,
axis
=
2
),
[
data
])
utt
.
verify_grad
(
lambda
x
:
argsort
(
x
,
axis
=
2
),
[
data
])
'''
class
Test_TopK
(
unittest
.
TestCase
):
class
Test_TopK
(
unittest
.
TestCase
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论