提交 ba48c060 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron 提交者: notoraptor

Fix compilation of GpuTopKOp

上级 2716a8be
......@@ -4,7 +4,7 @@
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)
// works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense(
extern "C" __global__ void k_topk_dense(
$dims
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv
......
......@@ -194,7 +194,7 @@ __device__ void radix_select(DataType* data,
*top_kth = RadixConfig<DataType>::deconvert(known_bits);
}
KERNEL void KERNEL_NAME(
extern "C" __global__ void KERNEL_NAME(
$dims
// size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv
......
......@@ -53,7 +53,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
pygpu.get_include()]
def c_code_cache_version(self):
return (1,)
return (2,)
def gpu_kernels(self, node, nodename):
# load kernel source
......@@ -120,7 +120,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
) as f:
kernel_src = f.read()
ker = Kernel(
code=Template(common_src + kernel_src).substitute(**subs),
code=("#include <cluda.h>\n" +
Template(common_src + kernel_src).substitute(**subs)),
name=kname,
params=param_types,
flags=flags,
......@@ -159,7 +160,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
ctx = sub['params']
k_dtype = node.inputs[1].type.dtype_specs()[1]
# max threads per block
MAX_TPB = context.maxlsize
MAX_TPB = context.maxlsize0
# max blocks per grid
MAX_BPG = context.maxgsize0
WARP_SIZE = 32
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论