提交 c5194587 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: notoraptor

Only use the variable when they exist and are needed.

上级 2fb56225
...@@ -21,7 +21,7 @@ extern "C" __global__ void k_topk_dense( ...@@ -21,7 +21,7 @@ extern "C" __global__ void k_topk_dense(
// ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM} // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ssize_t k, ssize_t k,
INPUT_TYPE* src, INPUT_TYPE* src,
size_t src_offset size_t src_offset,
$src_strides $src_strides
// ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM} // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
size_t size) { size_t size) {
...@@ -33,10 +33,6 @@ extern "C" __global__ void k_topk_dense( ...@@ -33,10 +33,6 @@ extern "C" __global__ void k_topk_dense(
size_t out_idx; size_t out_idx;
const unsigned char warp_id = idx / GA_WARP_SIZE; const unsigned char warp_id = idx / GA_WARP_SIZE;
dstv = ptr_add(dstv, dstv_offset);
dsti = ptr_add(dsti, dsti_offset);
src = ptr_add(src, src_offset);
// 0. get the slice for thread block to work on // 0. get the slice for thread block to work on
size_t gid = blockIdx.x, gidx; size_t gid = blockIdx.x, gidx;
......
...@@ -211,7 +211,7 @@ extern "C" __global__ void KERNEL_NAME( ...@@ -211,7 +211,7 @@ extern "C" __global__ void KERNEL_NAME(
// ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM} // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ssize_t k, ssize_t k,
INPUT_TYPE* src, INPUT_TYPE* src,
size_t src_offset size_t src_offset,
$src_strides $src_strides
// ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM} // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
size_t size) { size_t size) {
...@@ -222,9 +222,6 @@ extern "C" __global__ void KERNEL_NAME( ...@@ -222,9 +222,6 @@ extern "C" __global__ void KERNEL_NAME(
k = (order ? k : -k); k = (order ? k : -k);
const int idx = threadIdx.x; const int idx = threadIdx.x;
const int warp_id = idx / GA_WARP_SIZE; const int warp_id = idx / GA_WARP_SIZE;
dstv = ptr_add(dstv, dstv_offset);
dsti = ptr_add(dsti, dsti_offset);
src = ptr_add(src, src_offset);
// get the slice for thread block to work on // get the slice for thread block to work on
// size <- the axis to work on // size <- the axis to work on
......
...@@ -77,6 +77,17 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -77,6 +77,17 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '') dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '')
set_slice_code = ''.join( set_slice_code = ''.join(
set_slice_code % dict(i=j) for j in range(1, ndim)) set_slice_code % dict(i=j) for j in range(1, ndim))
if self.return_values:
set_slice_code += """
dstv = ptr_add(dstv, dstv_offset);
"""
if self.return_indices:
set_slice_code += """
dsti = ptr_add(dsti, dsti_offset);
"""
set_slice_code += """
src = ptr_add(src, src_offset);
"""
flags = Kernel.get_flags(node.inputs[0].dtype) flags = Kernel.get_flags(node.inputs[0].dtype)
subs = dict( subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype), inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论