提交 5582a715 authored 作者: Adam Becker's avatar Adam Becker

topk gpu bug fixes

- fixed gpu indexing error in small kernel - make RadixType at least 32 bits
上级 7072d479
...@@ -126,13 +126,13 @@ struct RadixConfig { ...@@ -126,13 +126,13 @@ struct RadixConfig {
// We use this to enable radix selection of floating-point values. // We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they // This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent // will all be adjacent
typedef T RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(T v) { static inline __device__ RadixType convert(T v) {
return v; return (RadixType)v;
} }
static inline __device__ float deconvert(RadixType v) { static inline __device__ float deconvert(RadixType v) {
return v; return (T)v;
} }
}; };
...@@ -173,7 +173,7 @@ struct RadixConfig<ga_double> { ...@@ -173,7 +173,7 @@ struct RadixConfig<ga_double> {
template <> template <>
struct RadixConfig<ga_byte> { struct RadixConfig<ga_byte> {
typedef ga_ubyte RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(ga_byte v) { static inline __device__ RadixType convert(ga_byte v) {
return 128u + v; return 128u + v;
...@@ -186,7 +186,7 @@ struct RadixConfig<ga_byte> { ...@@ -186,7 +186,7 @@ struct RadixConfig<ga_byte> {
template <> template <>
struct RadixConfig<ga_short> { struct RadixConfig<ga_short> {
typedef ga_ushort RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(ga_short v) { static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2); assert(sizeof(ga_short) == 2);
...@@ -232,7 +232,7 @@ struct RadixConfig<ga_long> { ...@@ -232,7 +232,7 @@ struct RadixConfig<ga_long> {
// since ga_half is ushort, use macro to protect this part is necessary // since ga_half is ushort, use macro to protect this part is necessary
template <> template <>
struct RadixConfig<ga_half> { struct RadixConfig<ga_half> {
typedef ga_ushort RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(ga_half v) { static inline __device__ RadixType convert(ga_half v) {
RadixType mask = -(((RadixType)v >> 15)) | 0x8000; RadixType mask = -(((RadixType)v >> 15)) | 0x8000;
......
...@@ -100,10 +100,9 @@ KERNEL void k_topk_dense( ...@@ -100,10 +100,9 @@ KERNEL void k_topk_dense(
if (idx==0) { if (idx==0) {
#pragma unroll #pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) { for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
if (smem[bin] <= 0) { if (smem[bin] <= 0)
k2 = -smem[bin];
break; break;
} k2 = smem[bin];
} }
} }
local_barrier(); local_barrier();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论