Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a3624d6f
提交
a3624d6f
authored
8月 19, 2017
作者:
Adam Becker
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
CLUDA code -> pure CUDA
上级
62ce362d
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
82 行增加
和
82 行删除
+82
-82
topk_common.cu
theano/gpuarray/c_code/topk_common.cu
+39
-39
topk_dense.cu
theano/gpuarray/c_code/topk_dense.cu
+18
-18
topk_dense_large.cu
theano/gpuarray/c_code/topk_dense_large.cu
+21
-21
sort.py
theano/gpuarray/sort.py
+4
-4
没有找到文件。
theano/gpuarray/c_code/topk_common.cu
浏览文件 @
a3624d6f
...
@@ -126,7 +126,7 @@ struct RadixConfig {
...
@@ -126,7 +126,7 @@ struct RadixConfig {
// We use this to enable radix selection of floating-point values.
// We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they
// This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent
// will all be adjacent
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(T v) {
static inline __device__ RadixType convert(T v) {
return (RadixType)v;
return (RadixType)v;
}
}
...
@@ -137,17 +137,17 @@ struct RadixConfig {
...
@@ -137,17 +137,17 @@ struct RadixConfig {
};
};
template <>
template <>
struct RadixConfig<
ga_
float> {
struct RadixConfig<float> {
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(
ga_
float v) {
static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v);
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
return (x ^ mask);
}
}
static inline __device__
ga_
float deconvert(RadixType v) {
static inline __device__ float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
return __int_as_float(v ^ mask);
...
@@ -155,16 +155,16 @@ struct RadixConfig<ga_float> {
...
@@ -155,16 +155,16 @@ struct RadixConfig<ga_float> {
};
};
template <>
template <>
struct RadixConfig<
ga_
double> {
struct RadixConfig<double> {
typedef
ga_u
long RadixType;
typedef
unsigned long
long RadixType;
static inline __device__ RadixType convert(
ga_
double v) {
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
return (x ^ mask);
}
}
static inline __device__
ga_
double deconvert(RadixType v) {
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
return __longlong_as_double(v ^ mask);
}
}
...
@@ -172,52 +172,52 @@ struct RadixConfig<ga_double> {
...
@@ -172,52 +172,52 @@ struct RadixConfig<ga_double> {
template <>
template <>
struct RadixConfig<
ga_byte
> {
struct RadixConfig<
char
> {
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(
ga_byte
v) {
static inline __device__ RadixType convert(
char
v) {
return 128u + v;
return 128u + v;
}
}
static inline __device__
ga_byte
deconvert(RadixType v) {
static inline __device__
char
deconvert(RadixType v) {
return v - 128;
return v - 128;
}
}
};
};
template <>
template <>
struct RadixConfig<
ga_
short> {
struct RadixConfig<short> {
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(
ga_
short v) {
static inline __device__ RadixType convert(short v) {
assert(sizeof(
ga_
short) == 2);
assert(sizeof(short) == 2);
return 32768u ^ v;
return 32768u ^ v;
}
}
static inline __device__
ga_
short deconvert(RadixType v) {
static inline __device__ short deconvert(RadixType v) {
return v - 32768;
return v - 32768;
}
}
};
};
template <>
template <>
struct RadixConfig<
ga_
int> {
struct RadixConfig<int> {
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(
ga_
int v) {
static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4);
assert(sizeof(int) == 4);
return 2147483648u + v;
return 2147483648u + v;
}
}
static inline __device__
ga_
int deconvert(RadixType v) {
static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u;
return v - 2147483648u;
}
}
};
};
template <>
template <>
struct RadixConfig<
ga_
long> {
struct RadixConfig<
long
long> {
typedef
ga_u
long RadixType;
typedef
unsigned long
long RadixType;
static inline __device__ RadixType convert(
ga_
long v) {
static inline __device__ RadixType convert(
long
long v) {
assert(sizeof(
ga_
long) == 8);
assert(sizeof(
long
long) == 8);
return 9223372036854775808ull + v;
return 9223372036854775808ull + v;
}
}
...
@@ -229,19 +229,19 @@ struct RadixConfig<ga_long> {
...
@@ -229,19 +229,19 @@ struct RadixConfig<ga_long> {
#define USE_HALF $use_half
#define USE_HALF $use_half
#if USE_HALF == 1
#if USE_HALF == 1
// since
ga_half is ushort, use
macro to protect this part is necessary
// since
half is ushort, using
macro to protect this part is necessary
template <>
template <>
struct RadixConfig<
ga_half
> {
struct RadixConfig<
unsigned short
> {
typedef
ga_u
int RadixType;
typedef
unsigned
int RadixType;
static inline __device__ RadixType convert(
ga_half
v) {
static inline __device__ RadixType convert(
unsigned short
v) {
RadixType mask = -(((RadixType)v >> 15)) | 0x8000;
RadixType mask = -(((RadixType)v >> 15)) | 0x8000;
return (v ^ mask);
return (v ^ mask);
}
}
static inline __device__
ga_half
deconvert(RadixType v) {
static inline __device__
unsigned short
deconvert(RadixType v) {
RadixType mask = ((v >> 15) - 1) | 0x8000;
RadixType mask = ((v >> 15) - 1) | 0x8000;
return (
ga_half
)(v ^ mask);
return (
unsigned short
)(v ^ mask);
}
}
};
};
#endif // USE_HALF
#endif // USE_HALF
...
@@ -274,7 +274,7 @@ static inline __device__ T binary_cumsum(
...
@@ -274,7 +274,7 @@ static inline __device__ T binary_cumsum(
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// cumsum within warp
// cumsum within warp
ga_u
int warp_bits = __ballot(value);
unsigned
int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_le() & warp_bits);
T warp_sum = __popc(lane_mask_le() & warp_bits);
if (lane_id() == 0)
if (lane_id() == 0)
...
@@ -285,7 +285,7 @@ static inline __device__ T binary_cumsum(
...
@@ -285,7 +285,7 @@ static inline __device__ T binary_cumsum(
// cumsum across warps in one thread
// cumsum across warps in one thread
if (idx == 0) {
if (idx == 0) {
T sum = smem[0];
T sum = smem[0];
for (int i = 1; i <
LDIM_0
/ GA_WARP_SIZE; ++i) {
for (int i = 1; i <
blockDim.x
/ GA_WARP_SIZE; ++i) {
sum += smem[i];
sum += smem[i];
smem[i] = sum;
smem[i] = sum;
}
}
...
@@ -309,7 +309,7 @@ static inline __device__ T binary_cumsum_exclusive(
...
@@ -309,7 +309,7 @@ static inline __device__ T binary_cumsum_exclusive(
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// cumsum within warp
// cumsum within warp
ga_u
int warp_bits = __ballot(value);
unsigned
int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_lt() & warp_bits);
T warp_sum = __popc(lane_mask_lt() & warp_bits);
if (lane_id() == 0)
if (lane_id() == 0)
...
@@ -320,7 +320,7 @@ static inline __device__ T binary_cumsum_exclusive(
...
@@ -320,7 +320,7 @@ static inline __device__ T binary_cumsum_exclusive(
// cumsum across warps in one thread
// cumsum across warps in one thread
if (idx == 0) {
if (idx == 0) {
T sum = smem[0];
T sum = smem[0];
for (int i = 1; i <
LDIM_0
/ GA_WARP_SIZE; ++i) {
for (int i = 1; i <
blockDim.x
/ GA_WARP_SIZE; ++i) {
sum += smem[i];
sum += smem[i];
smem[i] = sum;
smem[i] = sum;
}
}
...
@@ -337,19 +337,19 @@ static inline __device__ T binary_cumsum_exclusive(
...
@@ -337,19 +337,19 @@ static inline __device__ T binary_cumsum_exclusive(
// apply raw(byte) offset to pointer
// apply raw(byte) offset to pointer
template <typename T>
template <typename T>
static __device__ inline T* ptr_add(T *ptr,
ga_ssize
offset) {
static __device__ inline T* ptr_add(T *ptr,
ssize_t
offset) {
return (T*)((char*)ptr + offset);
return (T*)((char*)ptr + offset);
}
}
// get array element using raw(byte) offset
// get array element using raw(byte) offset
template <typename T>
template <typename T>
static __device__ inline T& ptr_at(T *ptr,
ga_ssize
offset) {
static __device__ inline T& ptr_at(T *ptr,
ssize_t
offset) {
return *((T*)((char*)ptr + offset));
return *((T*)((char*)ptr + offset));
}
}
// read array element using raw(byte) offset
// read array element using raw(byte) offset
template <typename T>
template <typename T>
static __device__ inline T ptr_read_cached(T *ptr,
ga_ssize
offset) {
static __device__ inline T ptr_read_cached(T *ptr,
ssize_t
offset) {
return __ldg(((T*)((char*)ptr + offset)));
return __ldg(((T*)((char*)ptr + offset)));
}
}
theano/gpuarray/c_code/topk_dense.cu
浏览文件 @
a3624d6f
...
@@ -6,32 +6,32 @@
...
@@ -6,32 +6,32 @@
// works when length on axis is within max allowed threads in block (1024)
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
KERNEL void k_topk_dense(
$dims
$dims
//
ga_size dims_1, ga_ssize
dims_2, ... , dims_$${NDIM}
//
size_t dims_1, ssize_t
dims_2, ... , dims_$${NDIM}
$dstv
$dstv
// INPUT_TYPE *dstv
// INPUT_TYPE *dstv
$dstv_strides
$dstv_strides
//
ga_ssize dstv_strides_0, ga_ssize
dstv_strides_1, ... , dstv_strides_$${NDIM}
//
ssize_t dstv_strides_0, ssize_t
dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
$dsti
// INDEX_TYPE *dsti
// INDEX_TYPE *dsti
$dsti_strides
$dsti_strides
//
ga_ssize dsti_strides_0, ga_ssize
dsti_strides_1, ... , dsti_strides_$${NDIM}
//
ssize_t dsti_strides_0, ssize_t
dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize
k,
ssize_t
k,
INPUT_TYPE* src,
INPUT_TYPE* src,
$src_strides
$src_strides
//
ga_ssize src_strides_0, ga_ssize
src_strides_1, ... , src_strides_$${NDIM}
//
ssize_t src_strides_0, ssize_t
src_strides_1, ... , src_strides_$${NDIM}
ga_size
size) {
size_t
size) {
LOCAL_MEM ga_
int smem[32 * RADIX_SIZE];
__shared__
int smem[32 * RADIX_SIZE];
LOCAL_MEM ga_
int k2;
__shared__
int k2;
const
ga_uint idx = LID_0
;
const
unsigned int idx = threadIdx.x
;
bool is_topk= (idx < size);
bool is_topk= (idx < size);
bool is_topkth = is_topk;
bool is_topkth = is_topk;
ga_size
out_idx;
size_t
out_idx;
const
ga_ubyte
warp_id = idx / GA_WARP_SIZE;
const
unsigned char
warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on
// 0. get the slice for thread block to work on
ga_size gid = GID_0
, gidx;
size_t gid = blockIdx.x
, gidx;
$set_slice
$set_slice
// $$set_slice expands into:
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
...
@@ -55,22 +55,22 @@ KERNEL void k_topk_dense(
...
@@ -55,22 +55,22 @@ KERNEL void k_topk_dense(
#pragma unroll
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
const
ga_
int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*
ga_
int digit = (x>>i) & (RADIX_SIZE-1);*/
/*int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp
// count within warp
#pragma unroll
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool vote = (bin == digit) && is_topkth;
bool vote = (bin == digit) && is_topkth;
ga_u
int votes = __ballot(vote);
unsigned
int votes = __ballot(vote);
if (lane_id()==0)
if (lane_id()==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
}
}
local_barrier();
local_barrier();
// sum counts across all warps
// sum counts across all warps
if (idx < RADIX_SIZE) {
if (idx < RADIX_SIZE) {
ga_
int sum = smem[idx];
int sum = smem[idx];
#pragma unroll
#pragma unroll
for(int w=RADIX_SIZE; w<
LDIM_0
*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
for(int w=RADIX_SIZE; w<
blockDim.x
*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
sum += smem[idx + w];
sum += smem[idx + w];
smem[idx] = sum;
smem[idx] = sum;
}
}
...
@@ -79,7 +79,7 @@ KERNEL void k_topk_dense(
...
@@ -79,7 +79,7 @@ KERNEL void k_topk_dense(
// find the bucket and update k2
// find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) {
if (idx == 0) {
ga_
int sum = k2;
int sum = k2;
#pragma unroll
#pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
sum -= smem[bin];
sum -= smem[bin];
...
...
theano/gpuarray/c_code/topk_dense_large.cu
浏览文件 @
a3624d6f
...
@@ -14,13 +14,13 @@ __device__ DataType find_pattern(DataType* smem,
...
@@ -14,13 +14,13 @@ __device__ DataType find_pattern(DataType* smem,
CountType stride,
CountType stride,
RadixType known_bits,
RadixType known_bits,
RadixType known_bits_mask) {
RadixType known_bits_mask) {
if (
LID_0
< 32)
if (
threadIdx.x
< 32)
smem[
LID_0
] = 0;
smem[
threadIdx.x
] = 0;
local_barrier();
local_barrier();
// All threads participate in the loop, in order to sync on the flag
// All threads participate in the loop, in order to sync on the flag
for (CountType i =
LID_0; i < (slice_size + (CountType)LDIM_0-1); i += LDIM_0
) {
for (CountType i =
threadIdx.x; i < (slice_size + (CountType)blockDim.x-1); i += blockDim.x
) {
bool in_range = (i < slice_size);
bool in_range = (i < slice_size);
DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;
DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;
...
@@ -64,14 +64,14 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
...
@@ -64,14 +64,14 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
for (int i = 0; i < RADIX_SIZE; ++i)
for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0;
counts[i] = 0;
if (
LID_0
< RADIX_SIZE)
if (
threadIdx.x
< RADIX_SIZE)
smem[
LID_0
] = 0;
smem[
threadIdx.x
] = 0;
local_barrier();
local_barrier();
// Scan over all the data. Upon a read, the warp will accumulate
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
// counts per each digit in the radix using warp voting.
for (CountType i =
LID_0; i < slice_size; i += LDIM_0
) {
for (CountType i =
threadIdx.x; i < slice_size; i += blockDim.x
) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool has_val = ((val & known_bits_mask) == known_bits);
bool has_val = ((val & known_bits_mask) == known_bits);
...
@@ -196,32 +196,32 @@ __device__ void radix_select(DataType* data,
...
@@ -196,32 +196,32 @@ __device__ void radix_select(DataType* data,
KERNEL void KERNEL_NAME(
KERNEL void KERNEL_NAME(
$dims
$dims
//
ga_size dims_1, ga_ssize
dims_2, ... , dims_$${NDIM}
//
size_t dims_1, ssize_t
dims_2, ... , dims_$${NDIM}
$dstv
$dstv
// INPUT_TYPE *dstv
// INPUT_TYPE *dstv
$dstv_strides
$dstv_strides
//
ga_ssize dstv_strides_0, ga_ssize
dstv_strides_1, ... , dstv_strides_$${NDIM}
//
ssize_t dstv_strides_0, ssize_t
dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
$dsti
// INDEX_TYPE *dsti
// INDEX_TYPE *dsti
$dsti_strides
$dsti_strides
//
ga_ssize dsti_strides_0, ga_ssize
dsti_strides_1, ... , dsti_strides_$${NDIM}
//
ssize_t dsti_strides_0, ssize_t
dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize
k,
ssize_t
k,
INPUT_TYPE* src,
INPUT_TYPE* src,
$src_strides
$src_strides
//
ga_ssize src_strides_0, ga_ssize
src_strides_1, ... , src_strides_$${NDIM}
//
ssize_t src_strides_0, ssize_t
src_strides_1, ... , src_strides_$${NDIM}
ga_size
size) {
size_t
size) {
LOCAL_MEM
COUNT_TYPE smem[32];
__shared__
COUNT_TYPE smem[32];
INPUT_TYPE topkth_value;
INPUT_TYPE topkth_value;
const bool order = (k>0);
const bool order = (k>0);
k = (order ? k : -k);
k = (order ? k : -k);
const
ga_int idx = LID_0
;
const
int idx = threadIdx.x
;
const
ga_
int warp_id = idx / GA_WARP_SIZE;
const int warp_id = idx / GA_WARP_SIZE;
// get the slice for thread block to work on
// get the slice for thread block to work on
// size <- the axis to work on
// size <- the axis to work on
// dims_1+ <- batched dimensions
// dims_1+ <- batched dimensions
ga_uint gid = GID_0
, gidx;
unsigned int gid = blockIdx.x
, gidx;
$set_slice
$set_slice
// $$set_slice expands into:
// $$set_slice expands into:
//for(int i=1; i<NDIM; i++) {
//for(int i=1; i<NDIM; i++) {
...
@@ -250,10 +250,10 @@ KERNEL void KERNEL_NAME(
...
@@ -250,10 +250,10 @@ KERNEL void KERNEL_NAME(
// All threads need to participate in the loop and the cumsum
// All threads need to participate in the loop and the cumsum
// but not necessarily in the load; hence loop bounds being rounded
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
// up to a multiple of the block dim.
COUNT_TYPE iter_bound = size +
LDIM_0
-1;
COUNT_TYPE iter_bound = size +
blockDim.x
-1;
INDEX_TYPE write_base = 0;
INDEX_TYPE write_base = 0;
for (int i = idx; i < iter_bound; i +=
LDIM_0
) {
for (int i = idx; i < iter_bound; i +=
blockDim.x
) {
bool in_range = (i < size);
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk;
bool has_topk;
...
@@ -264,7 +264,7 @@ KERNEL void KERNEL_NAME(
...
@@ -264,7 +264,7 @@ KERNEL void KERNEL_NAME(
}
}
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[
LDIM_0
/ 32 - 1];
int carry = smem[
blockDim.x
/ 32 - 1];
if (has_topk) {
if (has_topk) {
COUNT_TYPE write_idx = write_base + index;
COUNT_TYPE write_idx = write_base + index;
...
@@ -281,13 +281,13 @@ KERNEL void KERNEL_NAME(
...
@@ -281,13 +281,13 @@ KERNEL void KERNEL_NAME(
COUNT_TYPE topk_remaining = (k - write_base);
COUNT_TYPE topk_remaining = (k - write_base);
for (COUNT_TYPE i = idx; i < iter_bound; i +=
LDIM_0
) {
for (COUNT_TYPE i = idx; i < iter_bound; i +=
blockDim.x
) {
bool in_range = (i < size);
bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk = in_range && (v == topkth_value);
bool has_topk = in_range && (v == topkth_value);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[
LDIM_0
/ 32 - 1];
int carry = smem[
blockDim.x
/ 32 - 1];
if (has_topk && index < topk_remaining) {
if (has_topk && index < topk_remaining) {
COUNT_TYPE write_idx = write_base + index;
COUNT_TYPE write_idx = write_base + index;
...
...
theano/gpuarray/sort.py
浏览文件 @
a3624d6f
...
@@ -63,9 +63,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -63,9 +63,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
# prepare "$" macros
# prepare "$" macros
if
device_type
==
b
'cuda'
:
if
device_type
==
b
'cuda'
:
ndim
=
node
.
inputs
[
0
]
.
ndim
ndim
=
node
.
inputs
[
0
]
.
ndim
dstv_strides_code
=
''
.
join
(
'
ga_ssize
dstv_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dstv_strides_code
=
''
.
join
(
'
ssize_t
dstv_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dsti_strides_code
=
''
.
join
(
'
ga_ssize
dsti_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
dsti_strides_code
=
''
.
join
(
'
ssize_t
dsti_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
src_strides_code
=
''
.
join
(
'
ga_ssize
src_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
src_strides_code
=
''
.
join
(
'
ssize_t
src_strides_
%
d, '
%
i
for
i
in
range
(
ndim
))
set_slice_code
=
'''
set_slice_code
=
'''
gidx = gid
%%
dims_
%(i)
d;
gidx = gid
%%
dims_
%(i)
d;
gid /= dims_
%(i)
d;
gid /= dims_
%(i)
d;
...
@@ -80,7 +80,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
...
@@ -80,7 +80,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
subs
=
dict
(
subs
=
dict
(
inp_t
=
ga
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
),
inp_t
=
ga
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
),
out_t
=
ga
.
dtype_to_ctype
(
self
.
idx_dtype
),
out_t
=
ga
.
dtype_to_ctype
(
self
.
idx_dtype
),
dims
=
''
.
join
(
'
ga_size
dims_
%
d, '
%
i
for
i
in
range
(
1
,
ndim
)),
dims
=
''
.
join
(
'
size_t
dims_
%
d, '
%
i
for
i
in
range
(
1
,
ndim
)),
dstv
=
'INPUT_TYPE *dstv,'
if
self
.
return_values
else
''
,
dstv
=
'INPUT_TYPE *dstv,'
if
self
.
return_values
else
''
,
dsti
=
'INDEX_TYPE *dsti,'
if
self
.
return_indices
else
''
,
dsti
=
'INDEX_TYPE *dsti,'
if
self
.
return_indices
else
''
,
dstv_strides
=
dstv_strides_code
if
self
.
return_values
else
''
,
dstv_strides
=
dstv_strides_code
if
self
.
return_values
else
''
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论