Commit ba48c060 authored by Arnaud Bergeron, committed by notoraptor

Fix compilation of GpuTopKOp

Parent 2716a8be
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS) #define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024) // works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense( extern "C" __global__ void k_topk_dense(
$dims $dims
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM} // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv $dstv
......
...@@ -194,7 +194,7 @@ __device__ void radix_select(DataType* data, ...@@ -194,7 +194,7 @@ __device__ void radix_select(DataType* data,
*top_kth = RadixConfig<DataType>::deconvert(known_bits); *top_kth = RadixConfig<DataType>::deconvert(known_bits);
} }
KERNEL void KERNEL_NAME( extern "C" __global__ void KERNEL_NAME(
$dims $dims
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM} // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv $dstv
......
...@@ -53,7 +53,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -53,7 +53,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
pygpu.get_include()] pygpu.get_include()]
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def gpu_kernels(self, node, nodename): def gpu_kernels(self, node, nodename):
# load kernel source # load kernel source
...@@ -120,7 +120,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -120,7 +120,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
) as f: ) as f:
kernel_src = f.read() kernel_src = f.read()
ker = Kernel( ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs), code=("#include <cluda.h>\n" +
Template(common_src + kernel_src).substitute(**subs)),
name=kname, name=kname,
params=param_types, params=param_types,
flags=flags, flags=flags,
...@@ -159,7 +160,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -159,7 +160,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
ctx = sub['params'] ctx = sub['params']
k_dtype = node.inputs[1].type.dtype_specs()[1] k_dtype = node.inputs[1].type.dtype_specs()[1]
# max threads per block # max threads per block
MAX_TPB = context.maxlsize MAX_TPB = context.maxlsize0
# max blocks per grid # max blocks per grid
MAX_BPG = context.maxgsize0 MAX_BPG = context.maxgsize0
WARP_SIZE = 32 WARP_SIZE = 32
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment