提交 a3624d6f authored 作者: Adam Becker's avatar Adam Becker

CLUDA code -> pure CUDA

上级 62ce362d
...@@ -126,7 +126,7 @@ struct RadixConfig { ...@@ -126,7 +126,7 @@ struct RadixConfig {
// We use this to enable radix selection of floating-point values. // We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they // This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent // will all be adjacent
typedef ga_uint RadixType; typedef unsigned int RadixType;
static inline __device__ RadixType convert(T v) { static inline __device__ RadixType convert(T v) {
return (RadixType)v; return (RadixType)v;
} }
...@@ -137,17 +137,17 @@ struct RadixConfig { ...@@ -137,17 +137,17 @@ struct RadixConfig {
}; };
template <> template <>
struct RadixConfig<ga_float> { struct RadixConfig<float> {
typedef ga_uint RadixType; typedef unsigned int RadixType;
static inline __device__ RadixType convert(ga_float v) { static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v); RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask); return (x ^ mask);
} }
static inline __device__ ga_float deconvert(RadixType v) { static inline __device__ float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask); return __int_as_float(v ^ mask);
...@@ -155,16 +155,16 @@ struct RadixConfig<ga_float> { ...@@ -155,16 +155,16 @@ struct RadixConfig<ga_float> {
}; };
template <> template <>
struct RadixConfig<ga_double> { struct RadixConfig<double> {
typedef ga_ulong RadixType; typedef unsigned long long RadixType;
static inline __device__ RadixType convert(ga_double v) { static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v); RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000; RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask); return (x ^ mask);
} }
static inline __device__ ga_double deconvert(RadixType v) { static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask); return __longlong_as_double(v ^ mask);
} }
...@@ -172,52 +172,52 @@ struct RadixConfig<ga_double> { ...@@ -172,52 +172,52 @@ struct RadixConfig<ga_double> {
// Radix-key mapping for signed char: bias the value by 128 so the
// resulting unsigned key sorts in the same order as the signed input.
template <>
struct RadixConfig<char> {
    typedef unsigned int RadixType;

    // Map a char onto an order-preserving unsigned key.
    static inline __device__ RadixType convert(char v) {
        return (RadixType)v + 128u;
    }

    // Inverse of convert(): remove the bias to recover the char.
    static inline __device__ char deconvert(RadixType v) {
        return (char)(v - 128u);
    }
};
// Radix-key mapping for 16-bit short: flipping the sign bit yields an
// unsigned key whose ordering matches the signed ordering.
template <>
struct RadixConfig<short> {
    typedef unsigned int RadixType;

    // Map a short onto an order-preserving unsigned key.
    // Zero-extend before flipping the sign bit: the previous expression
    // (32768u ^ v) sign-extended v first, leaving 0xffff in the upper
    // half of the key for negative inputs. The low 16 bits (the only
    // ones the radix digit loop inspects) are unchanged by this fix.
    static inline __device__ RadixType convert(short v) {
        assert(sizeof(short) == 2);
        return 32768u ^ (RadixType)(unsigned short)v;
    }

    // Inverse of convert(): remove the bias; narrowing to short
    // recovers the original value.
    static inline __device__ short deconvert(RadixType v) {
        return v - 32768;
    }
};
// Radix-key mapping for 32-bit int: adding 2^31 modulo 2^32 is exactly
// a flip of the sign bit, which makes unsigned key order match signed
// value order.
template <>
struct RadixConfig<int> {
    typedef unsigned int RadixType;

    // Map an int onto an order-preserving unsigned key
    // (sign-bit flip; equivalent to the original 2147483648u + v).
    static inline __device__ RadixType convert(int v) {
        assert(sizeof(int) == 4);
        return (RadixType)v ^ 0x80000000u;
    }

    // Inverse of convert(): flip the sign bit back.
    static inline __device__ int deconvert(RadixType v) {
        return (int)(v ^ 0x80000000u);
    }
};
template <> template <>
struct RadixConfig<ga_long> { struct RadixConfig<long long> {
typedef ga_ulong RadixType; typedef unsigned long long RadixType;
static inline __device__ RadixType convert(ga_long v) { static inline __device__ RadixType convert(long long v) {
assert(sizeof(ga_long) == 8); assert(sizeof(long long) == 8);
return 9223372036854775808ull + v; return 9223372036854775808ull + v;
} }
...@@ -229,19 +229,19 @@ struct RadixConfig<ga_long> { ...@@ -229,19 +229,19 @@ struct RadixConfig<ga_long> {
#define USE_HALF $use_half #define USE_HALF $use_half
#if USE_HALF == 1 #if USE_HALF == 1
// since ga_half is ushort, use macro to protect this part is necessary // since half is ushort, using macro to protect this part is necessary
// Radix-key mapping for half (carried as unsigned short): same
// sign-magnitude trick as the float specialization, applied to the
// 16-bit pattern — negative halves have all bits flipped, non-negative
// ones only the sign bit.
template <>
struct RadixConfig<unsigned short> {
    typedef unsigned int RadixType;

    // Map a half bit-pattern onto an order-preserving unsigned key.
    // The final mask to 16 bits fixes a latent artifact: for negative
    // inputs the all-ones mask used to leave 0xffff in the upper half
    // of the 32-bit key. The low 16 bits (the only ones the radix
    // digit loop inspects) are unchanged.
    static inline __device__ RadixType convert(unsigned short v) {
        RadixType mask = -(((RadixType)v >> 15)) | 0x8000;
        return (v ^ mask) & 0xffffu;
    }

    // Inverse of convert(): undo the conditional bit flips.
    static inline __device__ unsigned short deconvert(RadixType v) {
        RadixType mask = ((v >> 15) - 1) | 0x8000;
        return (unsigned short)(v ^ mask);
    }
};
#endif // USE_HALF #endif // USE_HALF
...@@ -274,7 +274,7 @@ static inline __device__ T binary_cumsum( ...@@ -274,7 +274,7 @@ static inline __device__ T binary_cumsum(
// binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3) // binary_cumsum(1, 0, 1, 0, 1) -> (1, 1, 2, 2, 3)
// cumsum within warp // cumsum within warp
ga_uint warp_bits = __ballot(value); unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_le() & warp_bits); T warp_sum = __popc(lane_mask_le() & warp_bits);
if (lane_id() == 0) if (lane_id() == 0)
...@@ -285,7 +285,7 @@ static inline __device__ T binary_cumsum( ...@@ -285,7 +285,7 @@ static inline __device__ T binary_cumsum(
// cumsum across warps in one thread // cumsum across warps in one thread
if (idx == 0) { if (idx == 0) {
T sum = smem[0]; T sum = smem[0];
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) { for (int i = 1; i < blockDim.x / GA_WARP_SIZE; ++i) {
sum += smem[i]; sum += smem[i];
smem[i] = sum; smem[i] = sum;
} }
...@@ -309,7 +309,7 @@ static inline __device__ T binary_cumsum_exclusive( ...@@ -309,7 +309,7 @@ static inline __device__ T binary_cumsum_exclusive(
// binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2) // binary_cumsum_excl(1, 0, 1, 0, 1) -> (0, 1, 1, 2, 2)
// cumsum within warp // cumsum within warp
ga_uint warp_bits = __ballot(value); unsigned int warp_bits = __ballot(value);
T warp_sum = __popc(lane_mask_lt() & warp_bits); T warp_sum = __popc(lane_mask_lt() & warp_bits);
if (lane_id() == 0) if (lane_id() == 0)
...@@ -320,7 +320,7 @@ static inline __device__ T binary_cumsum_exclusive( ...@@ -320,7 +320,7 @@ static inline __device__ T binary_cumsum_exclusive(
// cumsum across warps in one thread // cumsum across warps in one thread
if (idx == 0) { if (idx == 0) {
T sum = smem[0]; T sum = smem[0];
for (int i = 1; i < LDIM_0 / GA_WARP_SIZE; ++i) { for (int i = 1; i < blockDim.x / GA_WARP_SIZE; ++i) {
sum += smem[i]; sum += smem[i];
smem[i] = sum; smem[i] = sum;
} }
...@@ -337,19 +337,19 @@ static inline __device__ T binary_cumsum_exclusive( ...@@ -337,19 +337,19 @@ static inline __device__ T binary_cumsum_exclusive(
// apply raw(byte) offset to pointer // apply raw(byte) offset to pointer
// Advance a typed pointer by a raw byte offset.
template <typename T>
static __device__ inline T* ptr_add(T *ptr, ssize_t offset) {
    char *raw = (char*)ptr;
    return (T*)(raw + offset);
}
// get array element using raw(byte) offset // get array element using raw(byte) offset
// Return a reference to the element located a raw byte offset from ptr.
template <typename T>
static __device__ inline T& ptr_at(T *ptr, ssize_t offset) {
    char *raw = (char*)ptr;
    return *(T*)(raw + offset);
}
// read array element using raw(byte) offset // read array element using raw(byte) offset
// Load the element at a raw byte offset from ptr through the
// read-only data cache (__ldg).
template <typename T>
static __device__ inline T ptr_read_cached(T *ptr, ssize_t offset) {
    T *elem = (T*)((char*)ptr + offset);
    return __ldg(elem);
}
...@@ -6,32 +6,32 @@ ...@@ -6,32 +6,32 @@
// works when length on axis is within max allowed threads in block (1024) // works when length on axis is within max allowed threads in block (1024)
KERNEL void k_topk_dense( KERNEL void k_topk_dense(
$dims $dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM} // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv $dstv
// INPUT_TYPE *dstv // INPUT_TYPE *dstv
$dstv_strides $dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM} // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti $dsti
// INDEX_TYPE *dsti // INDEX_TYPE *dsti
$dsti_strides $dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM} // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k, ssize_t k,
INPUT_TYPE* src, INPUT_TYPE* src,
$src_strides $src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM} // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
ga_size size) { size_t size) {
LOCAL_MEM ga_int smem[32 * RADIX_SIZE]; __shared__ int smem[32 * RADIX_SIZE];
LOCAL_MEM ga_int k2; __shared__ int k2;
const ga_uint idx = LID_0; const unsigned int idx = threadIdx.x;
bool is_topk= (idx < size); bool is_topk= (idx < size);
bool is_topkth = is_topk; bool is_topkth = is_topk;
ga_size out_idx; size_t out_idx;
const ga_ubyte warp_id = idx / GA_WARP_SIZE; const unsigned char warp_id = idx / GA_WARP_SIZE;
// 0. get the slice for thread block to work on // 0. get the slice for thread block to work on
ga_size gid = GID_0, gidx; size_t gid = blockIdx.x, gidx;
$set_slice $set_slice
// $$set_slice expands into: // $$set_slice expands into:
//for(int i=1; i<NDIM; i++) { //for(int i=1; i<NDIM; i++) {
...@@ -55,22 +55,22 @@ KERNEL void k_topk_dense( ...@@ -55,22 +55,22 @@ KERNEL void k_topk_dense(
#pragma unroll #pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) { for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
const ga_int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS); const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
/*ga_int digit = (x>>i) & (RADIX_SIZE-1);*/ /*int digit = (x>>i) & (RADIX_SIZE-1);*/
// count within warp // count within warp
#pragma unroll #pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) { for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool vote = (bin == digit) && is_topkth; bool vote = (bin == digit) && is_topkth;
ga_uint votes = __ballot(vote); unsigned int votes = __ballot(vote);
if (lane_id()==0) if (lane_id()==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(votes); smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
} }
local_barrier(); local_barrier();
// sum counts across all warps // sum counts across all warps
if (idx < RADIX_SIZE) { if (idx < RADIX_SIZE) {
ga_int sum = smem[idx]; int sum = smem[idx];
#pragma unroll #pragma unroll
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE) for(int w=RADIX_SIZE; w<blockDim.x*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
sum += smem[idx + w]; sum += smem[idx + w];
smem[idx] = sum; smem[idx] = sum;
} }
...@@ -79,7 +79,7 @@ KERNEL void k_topk_dense( ...@@ -79,7 +79,7 @@ KERNEL void k_topk_dense(
// find the bucket and update k2 // find the bucket and update k2
// smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1]) // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
if (idx == 0) { if (idx == 0) {
ga_int sum = k2; int sum = k2;
#pragma unroll #pragma unroll
for (int bin=RADIX_SIZE-1; bin>=0; --bin) { for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
sum -= smem[bin]; sum -= smem[bin];
......
...@@ -14,13 +14,13 @@ __device__ DataType find_pattern(DataType* smem, ...@@ -14,13 +14,13 @@ __device__ DataType find_pattern(DataType* smem,
CountType stride, CountType stride,
RadixType known_bits, RadixType known_bits,
RadixType known_bits_mask) { RadixType known_bits_mask) {
if (LID_0 < 32) if (threadIdx.x < 32)
smem[LID_0] = 0; smem[threadIdx.x] = 0;
local_barrier(); local_barrier();
// All threads participate in the loop, in order to sync on the flag // All threads participate in the loop, in order to sync on the flag
for (CountType i = LID_0; i < (slice_size + (CountType)LDIM_0-1); i += LDIM_0) { for (CountType i = threadIdx.x; i < (slice_size + (CountType)blockDim.x-1); i += blockDim.x) {
bool in_range = (i < slice_size); bool in_range = (i < slice_size);
DataType v = in_range ? ptr_read_cached(data, i*stride) : 0; DataType v = in_range ? ptr_read_cached(data, i*stride) : 0;
...@@ -64,14 +64,14 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE], ...@@ -64,14 +64,14 @@ __device__ void count_radix_masked(CountType counts[RADIX_SIZE],
for (int i = 0; i < RADIX_SIZE; ++i) for (int i = 0; i < RADIX_SIZE; ++i)
counts[i] = 0; counts[i] = 0;
if (LID_0 < RADIX_SIZE) if (threadIdx.x < RADIX_SIZE)
smem[LID_0] = 0; smem[threadIdx.x] = 0;
local_barrier(); local_barrier();
// Scan over all the data. Upon a read, the warp will accumulate // Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting. // counts per each digit in the radix using warp voting.
for (CountType i = LID_0; i < slice_size; i += LDIM_0) { for (CountType i = threadIdx.x; i < slice_size; i += blockDim.x) {
RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride)); RadixType val = RadixConfig<DataType>::convert(ptr_read_cached(data, i*stride));
bool has_val = ((val & known_bits_mask) == known_bits); bool has_val = ((val & known_bits_mask) == known_bits);
...@@ -196,32 +196,32 @@ __device__ void radix_select(DataType* data, ...@@ -196,32 +196,32 @@ __device__ void radix_select(DataType* data,
KERNEL void KERNEL_NAME( KERNEL void KERNEL_NAME(
$dims $dims
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM} // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
$dstv $dstv
// INPUT_TYPE *dstv // INPUT_TYPE *dstv
$dstv_strides $dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM} // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti $dsti
// INDEX_TYPE *dsti // INDEX_TYPE *dsti
$dsti_strides $dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM} // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k, ssize_t k,
INPUT_TYPE* src, INPUT_TYPE* src,
$src_strides $src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM} // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
ga_size size) { size_t size) {
LOCAL_MEM COUNT_TYPE smem[32]; __shared__ COUNT_TYPE smem[32];
INPUT_TYPE topkth_value; INPUT_TYPE topkth_value;
const bool order = (k>0); const bool order = (k>0);
k = (order ? k : -k); k = (order ? k : -k);
const ga_int idx = LID_0; const int idx = threadIdx.x;
const ga_int warp_id = idx / GA_WARP_SIZE; const int warp_id = idx / GA_WARP_SIZE;
// get the slice for thread block to work on // get the slice for thread block to work on
// size <- the axis to work on // size <- the axis to work on
// dims_1+ <- batched dimensions // dims_1+ <- batched dimensions
ga_uint gid = GID_0, gidx; unsigned int gid = blockIdx.x, gidx;
$set_slice $set_slice
// $$set_slice expands into: // $$set_slice expands into:
//for(int i=1; i<NDIM; i++) { //for(int i=1; i<NDIM; i++) {
...@@ -250,10 +250,10 @@ KERNEL void KERNEL_NAME( ...@@ -250,10 +250,10 @@ KERNEL void KERNEL_NAME(
// All threads need to participate in the loop and the cumsum // All threads need to participate in the loop and the cumsum
// but not necessarily in the load; hence loop bounds being rounded // but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim. // up to a multiple of the block dim.
COUNT_TYPE iter_bound = size + LDIM_0-1; COUNT_TYPE iter_bound = size + blockDim.x-1;
INDEX_TYPE write_base = 0; INDEX_TYPE write_base = 0;
for (int i = idx; i < iter_bound; i += LDIM_0) { for (int i = idx; i < iter_bound; i += blockDim.x) {
bool in_range = (i < size); bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0; INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk; bool has_topk;
...@@ -264,7 +264,7 @@ KERNEL void KERNEL_NAME( ...@@ -264,7 +264,7 @@ KERNEL void KERNEL_NAME(
} }
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk); int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1]; int carry = smem[blockDim.x / 32 - 1];
if (has_topk) { if (has_topk) {
COUNT_TYPE write_idx = write_base + index; COUNT_TYPE write_idx = write_base + index;
...@@ -281,13 +281,13 @@ KERNEL void KERNEL_NAME( ...@@ -281,13 +281,13 @@ KERNEL void KERNEL_NAME(
COUNT_TYPE topk_remaining = (k - write_base); COUNT_TYPE topk_remaining = (k - write_base);
for (COUNT_TYPE i = idx; i < iter_bound; i += LDIM_0) { for (COUNT_TYPE i = idx; i < iter_bound; i += blockDim.x) {
bool in_range = (i < size); bool in_range = (i < size);
INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0; INPUT_TYPE v = in_range ? ptr_read_cached(src, i*src_strides_0) : 0;
bool has_topk = in_range && (v == topkth_value); bool has_topk = in_range && (v == topkth_value);
int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk); int index = binary_cumsum_exclusive(idx, warp_id, smem, has_topk);
int carry = smem[LDIM_0 / 32 - 1]; int carry = smem[blockDim.x / 32 - 1];
if (has_topk && index < topk_remaining) { if (has_topk && index < topk_remaining) {
COUNT_TYPE write_idx = write_base + index; COUNT_TYPE write_idx = write_base + index;
......
...@@ -63,9 +63,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -63,9 +63,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
# prepare "$" macros # prepare "$" macros
if device_type == b'cuda': if device_type == b'cuda':
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
dstv_strides_code = ''.join('ga_ssize dstv_strides_%d, ' % i for i in range(ndim)) dstv_strides_code = ''.join('ssize_t dstv_strides_%d, ' % i for i in range(ndim))
dsti_strides_code = ''.join('ga_ssize dsti_strides_%d, ' % i for i in range(ndim)) dsti_strides_code = ''.join('ssize_t dsti_strides_%d, ' % i for i in range(ndim))
src_strides_code = ''.join('ga_ssize src_strides_%d, ' % i for i in range(ndim)) src_strides_code = ''.join('ssize_t src_strides_%d, ' % i for i in range(ndim))
set_slice_code = ''' set_slice_code = '''
gidx = gid %% dims_%(i)d; gidx = gid %% dims_%(i)d;
gid /= dims_%(i)d; gid /= dims_%(i)d;
...@@ -80,7 +80,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -80,7 +80,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
subs = dict( subs = dict(
inp_t=ga.dtype_to_ctype(node.inputs[0].dtype), inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
out_t=ga.dtype_to_ctype(self.idx_dtype), out_t=ga.dtype_to_ctype(self.idx_dtype),
dims=''.join('ga_size dims_%d, ' % i for i in range(1, ndim)), dims=''.join('size_t dims_%d, ' % i for i in range(1, ndim)),
dstv='INPUT_TYPE *dstv,' if self.return_values else '', dstv='INPUT_TYPE *dstv,' if self.return_values else '',
dsti='INDEX_TYPE *dsti,' if self.return_indices else '', dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
dstv_strides=dstv_strides_code if self.return_values else '', dstv_strides=dstv_strides_code if self.return_values else '',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论