Commit 933cb859 authored by Adam Becker

Implement new kernel to handle arbitrary shape

Parent bf83d342
...@@ -9,26 +9,52 @@ ...@@ -9,26 +9,52 @@
// will all be adjacent // will all be adjacent
template <typename T> template <typename T>
struct RadixConfig {}; struct RadixConfig {
typedef T RadixType;
static inline __device__ RadixType convert(T v) {
return v;
}
static inline __device__ float deconvert(RadixType v) {
return v;
}
};
template <> template <>
struct RadixConfig<float> { struct RadixConfig<ga_float> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(float v) { static inline __device__ RadixType convert(ga_float v) {
RadixType x = __float_as_int(v); RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask); return (x ^ mask);
} }
static inline __device__ float deconvert(RadixType v) { static inline __device__ ga_float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask); return __int_as_float(v ^ mask);
} }
}; };
template <>
struct RadixConfig<ga_double> {
  typedef unsigned long long int RadixType;
  // Map a double's bit pattern to an unsigned key whose integer order
  // matches the floating-point order: negative values get all bits
  // flipped, non-negative values get only the sign bit flipped.
  static inline __device__ RadixType convert(ga_double v) {
    RadixType bits = __double_as_longlong(v);
    RadixType flip = (bits >> 63) ? 0xffffffffffffffffull
                                  : 0x8000000000000000ull;
    return bits ^ flip;
  }
  // Inverse of convert(): recover the original double bit pattern.
  static inline __device__ ga_double deconvert(RadixType v) {
    RadixType flip = (v >> 63) ? 0x8000000000000000ull
                               : 0xffffffffffffffffull;
    return __longlong_as_double(v ^ flip);
  }
};
template <> template <>
struct RadixConfig<ga_ubyte> { struct RadixConfig<ga_ubyte> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
...@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> { ...@@ -43,14 +69,14 @@ struct RadixConfig<ga_ubyte> {
}; };
template <> template <>
struct RadixConfig<char> { struct RadixConfig<ga_byte> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
static inline __device__ RadixType convert(char v) { static inline __device__ RadixType convert(ga_byte v) {
return 128u + v; return 128u + v;
} }
static inline __device__ char deconvert(RadixType v) { static inline __device__ ga_byte deconvert(RadixType v) {
return v - 128; return v - 128;
} }
}; };
...@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> { ...@@ -61,7 +87,7 @@ struct RadixConfig<ga_short> {
static inline __device__ RadixType convert(ga_short v) { static inline __device__ RadixType convert(ga_short v) {
assert(sizeof(ga_short) == 2); assert(sizeof(ga_short) == 2);
return 32768u + v; return 32768u ^ v;
} }
static inline __device__ ga_short deconvert(RadixType v) { static inline __device__ ga_short deconvert(RadixType v) {
...@@ -75,45 +101,30 @@ struct RadixConfig<int> { ...@@ -75,45 +101,30 @@ struct RadixConfig<int> {
static inline __device__ RadixType convert(int v) { static inline __device__ RadixType convert(int v) {
assert(sizeof(int) == 4); assert(sizeof(int) == 4);
return 2147483648u + v; return (1u << 31) ^ v;
} }
static inline __device__ int deconvert(RadixType v) { static inline __device__ int deconvert(RadixType v) {
return v - 2147483648u; return (1u << 31) ^ v;
} }
}; };
template <> template <>
struct RadixConfig<long> { struct RadixConfig<ga_long> {
typedef unsigned long long int RadixType; typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(long v) { static inline __device__ RadixType convert(ga_long v) {
assert(sizeof(long) == 8); assert(sizeof(long) == 8);
return 9223372036854775808ull + v; return (1ull << 63) ^ v;
} }
static inline __device__ long deconvert(RadixType v) { static inline __device__ ga_long deconvert(RadixType v) {
return v - 9223372036854775808ull; return (1ull << 63) ^ v;
}
};
template <>
struct RadixConfig<double> {
typedef unsigned long long int RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
} }
}; };
#ifdef USE_HALF #ifdef USE_HALF
// TODO: make this work
template <> template <>
struct RadixConfig<half> { struct RadixConfig<half> {
typedef ga_uint RadixType; typedef ga_uint RadixType;
...@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) { ...@@ -242,135 +253,9 @@ static __device__ inline T& ptr_at(T *ptr, ga_ssize offset) {
return *((T*)((char*)ptr + offset)); return *((T*)((char*)ptr + offset));
} }
KERNEL void k_topk_dense( // read array element using raw(byte) offset
$dims template <typename T>
// ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM} static __device__ inline T ptr_read(T *ptr, ga_ssize offset) {
$dstv return __ldg(((T*)((char*)ptr + offset)));
// INPUT_TYPE *dstv
$dstv_strides
// ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
$dsti
// INDEX_TYPE *dsti
$dsti_strides
// ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
ga_ssize k,
INPUT_TYPE* src,
$src_strides
// ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
ga_size size) {
LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
ga_ssize LOCAL_MEM bins[RADIX_SIZE]; // TODO: does using 32-bit gives speedup?
bool is_topk=true, is_topkth=true;
radix_t out_idx;
const ga_size idx = LID_0;
ga_size LOCAL_MEM k2, exceed;
const ga_uint warp_id = idx / GA_WARP_SIZE;
const ga_uint lane_id = idx % GA_WARP_SIZE;
const bool in_range = (idx < size);
is_topk &= in_range;
// 0. get the slice for thread block to work on
// TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
ga_size gid = GID_0, gidx;
$set_slice
//for(int i=1; i<NDIM; i++) {
// gidx = gid % dims_$${i};
// gid /= dims_$${i};
// dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
// dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
// src = ptr_add(src, gidx*src_strides_$${i});
//}
// get input and its radix friendly form
const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;
// resolve negative k
if (k<0) { x = ~x; k = -k; }
if (idx==0) k2 = k;
// 1. filter is_topk and is_topkth using radix select
#pragma unroll
for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
int digit = (x>>i) & (RADIX_SIZE-1);
// count within warp
#pragma unroll
for (int bin=0; bin<RADIX_SIZE; ++bin) {
bool incr_bin = (bin == digit) && is_topkth && in_range;
ga_uint incr_bin_warp = __ballot(incr_bin);
if (lane_id==0)
smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
}
local_barrier();
// sum counts across all warps
// TODO: test in-block parallel sum?
if (idx < RADIX_SIZE) {
for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
smem[idx] += smem[idx + w];
}
local_barrier();
// calculate k minus cumsum(count)
if (idx<RADIX_SIZE)
bins[idx] = 0;
if (idx == 0) {
exceed = k; // how many the number of is_topk exceeds k
bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
if (bins[RADIX_SIZE-1] > 0)
k2 = bins[RADIX_SIZE-1];
else
exceed = min(exceed, bins[RADIX_SIZE-1]);
#pragma unroll
for(int bin=RADIX_SIZE-1; bin; --bin) {
bins[bin-1] = bins[bin] - smem[bin-1];
if (bins[bin-1] > 0)
k2 = bins[bin-1];
else
exceed = min(exceed, bins[bin-1]);
}
}
local_barrier();
// smem -> count
// bins -> k2 - cumsum(count)
if (is_topk && is_topkth) {
ga_ssize icount = bins[digit];
if (icount > 0) {
is_topkth = false;
} else if (icount < 0) {
if (digit+1!=RADIX_SIZE) {
if (bins[digit+1] <= 0) {
is_topk = false;
is_topkth = false;
}
}
}
}
}
// 2. find the index of output array, if exists
if (exceed != 0) {
// top_kth value may not be unique, so we need to
// perform binary cumsum on is_topkth to drop exceeding top-kth values
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
is_topk &= (out_idx < exceed);
}
// perform binary cumsum on is_topk to determine the indices to put result
out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
local_barrier();
if (is_topk) {
#if WRITE_VALUE == 1
ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
}
} }
// works when length on axis is within max allowed threads in block (1024)
//
// Block-wide radix top-k: each thread owns at most one input element
// (idx < size).  The block repeatedly histograms the current radix digit
// (most significant first) to narrow down the top-k-th value, then
// compacts the selected elements into dstv/dsti via binary cumsums.
// NOTE(review): uses the legacy mask-less __ballot(); the surrounding
// build must map it to __ballot_sync() on CUDA 9+ / sm_70+ — confirm.
KERNEL void k_topk_dense(
    $dims
    // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
    $dstv
    // INPUT_TYPE *dstv
    $dstv_strides
    // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
    $dsti
    // INDEX_TYPE *dsti
    $dsti_strides
    // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
    ga_ssize k,
    INPUT_TYPE* src,
    $src_strides
    // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
    ga_size size) {
  // per-warp digit histograms: up to 32 warps (1024/32) x RADIX_SIZE bins
  LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
  // bins[b] = k2 - cumsum(counts from the highest bin down to b).
  // The extra sentinel slot bins[RADIX_SIZE] (set to 1 below) keeps the
  // bins[digit+1] read in step 1 in bounds when digit == RADIX_SIZE-1.
  ga_ssize LOCAL_MEM bins[RADIX_SIZE+1]; // TODO: does using 32-bit gives good speedup?
  bool is_topk=true, is_topkth=true;
  radix_t out_idx;
  const ga_ushort idx = LID_0;
  ga_size LOCAL_MEM k2, exceed;
  const ga_ubyte warp_id = idx / GA_WARP_SIZE;
  const ga_ubyte lane_id = idx % GA_WARP_SIZE;
  const bool in_range = (idx < size);
  is_topk &= in_range;

  // 0. get the slice for thread block to work on
  // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
  ga_size gid = GID_0, gidx;
  $set_slice
  //for(int i=1; i<NDIM; i++) {
  //  gidx = gid % dims_$${i};
  //  gid /= dims_$${i};
  //  dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
  //  dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
  //  src = ptr_add(src, gidx*src_strides_$${i});
  //}

  // get input and its radix friendly form
  const INPUT_TYPE xval = in_range ? ptr_at(src, idx*src_strides_0) : (INPUT_TYPE)0;
  radix_t x = in_range ? RadixConfig<INPUT_TYPE>::convert(xval) : 0;

  // resolve negative k (k < 0 selects the smallest k: complement the keys)
  if (k<0) { x = ~x; k = -k; }
  if (idx==0) {
    k2 = k;
    // FIX: exceed was left uninitialized when the loop after step 1 finds
    // no non-positive bin (possible when k >= size); default to "no excess".
    exceed = 0;
    bins[RADIX_SIZE] = 1;  // sentinel, always > 0
  }

  // 1. filter is_topk and is_topkth using radix select
  #pragma unroll
  for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
    int digit = (x>>i) & (RADIX_SIZE-1);
    // count within warp
    #pragma unroll
    for (int bin=0; bin<RADIX_SIZE; ++bin) {
      bool incr_bin = (bin == digit) && is_topkth && in_range;
      ga_uint incr_bin_warp = __ballot(incr_bin);
      if (lane_id==0)
        smem[bin + RADIX_SIZE*warp_id] = __popc(incr_bin_warp);
    }
    local_barrier();
    // sum counts across all warps
    // TODO: test in-block parallel sum?
    if (idx < RADIX_SIZE) {
      for(int w=RADIX_SIZE; w<LDIM_0*RADIX_SIZE / GA_WARP_SIZE; w+=RADIX_SIZE)
        smem[idx] += smem[idx + w];
    }
    local_barrier();
    // bins = k - cumsum(smem[:RADIX_SIZE]), accumulated from the top bin
    // down; k2 tracks the remaining k inside the first non-exhausted bin
    if (idx == 0) {
      bins[RADIX_SIZE-1] = k2 - smem[RADIX_SIZE-1];
      if (bins[RADIX_SIZE-1] > 0)
        k2 = bins[RADIX_SIZE-1];
      #pragma unroll
      for (int bin=RADIX_SIZE-1; bin; --bin) {
        bins[bin-1] = bins[bin] - smem[bin-1];
        if (bins[bin-1] > 0)
          k2 = bins[bin-1];
      }
    }
    local_barrier();
    // smem -> count
    // bins -> k2 - cumsum(count)
    if (is_topk && is_topkth) {
      ga_ssize icount = bins[digit];
      if (icount > 0) {
        // strictly inside top-k: no longer a candidate for the k-th value
        is_topkth = false;
      } else if (bins[digit+1] <= 0) {
        // digit's whole bin lies below the k-th value: drop it
        // (bins[RADIX_SIZE] sentinel keeps this read in bounds)
        is_topk = false;
        is_topkth = false;
      }
    }
  }
  // exceed = how many elements tied with the top-k-th value overflow k
  if (idx==0) {
    #pragma unroll
    for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
      if (bins[bin] <= 0) {
        exceed = -bins[bin];
        break;
      }
    }
  }
  local_barrier();

  // 2. find the index of output array, if exists
  if (exceed != 0) {
    // top_kth value may not be unique, so we need to
    // perform binary cumsum on is_topkth to drop exceeding top-kth values
    out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topkth);
    is_topk &= ((!is_topkth) || out_idx>=exceed);
  }

  // perform binary cumsum on is_topk to determine the indices to put result
  out_idx = binary_cumsum_exclusive<radix_t>(idx, warp_id, lane_id, smem, is_topk);
  local_barrier();

  if (is_topk) {
#if WRITE_VALUE == 1
    ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
    ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
  }
}
// works when length on axis is larger than max allowed threads in block (1024)
//
// Radix-select top-k where each thread scans `inp_per_thread` consecutive
// elements: step 1 narrows down the exact bit pattern of the top-k-th
// value digit by digit; step 2 writes all values strictly greater than
// it; step 3 writes enough of the ties to reach exactly k outputs.
// NOTE(review): uses the legacy mask-less __ballot(); the surrounding
// build must map it to __ballot_sync() on CUDA 9+ / sm_70+ — confirm.
KERNEL void k_topk_dense_large(
    $dims
    // ga_size dims_1, ga_ssize dims_2, ... , dims_$${NDIM}
    $dstv
    // INPUT_TYPE *dstv
    $dstv_strides
    // ga_ssize dstv_strides_0, ga_ssize dstv_strides_1, ... , dstv_strides_$${NDIM}
    $dsti
    // INDEX_TYPE *dsti
    $dsti_strides
    // ga_ssize dsti_strides_0, ga_ssize dsti_strides_1, ... , dsti_strides_$${NDIM}
    ga_ssize k,
    INPUT_TYPE* src,
    $src_strides
    // ga_ssize src_strides_0, ga_ssize src_strides_1, ... , src_strides_$${NDIM}
    ga_size size, ga_ushort inp_per_thread) {
  // per-warp digit histograms: up to 32 warps x RADIX_SIZE bins
  LOCAL_MEM radix_t smem[32 * RADIX_SIZE];
  // bits of the top-k-th value resolved so far, and which bit positions
  // of known_bits are valid
  LOCAL_MEM radix_t known_bits, known_bits_mask;
  radix_t out_idx;
  ga_size LOCAL_MEM write_base;
  INPUT_TYPE xval;
  radix_t x;
  ga_int i;
  bool in_range, is_topk;
  const ga_size idx = LID_0;
  ga_size LOCAL_MEM k2;
  const ga_ushort warp_id = idx / GA_WARP_SIZE;
  const ga_ushort lane_id = idx % GA_WARP_SIZE;

  // 0. get the slice for thread block to work on
  // TODO if ndim <= 3, use native indexing ? (blockIdx.[xyz])
  ga_size gid = GID_0, gidx;
  $set_slice
  //for(int i=1; i<NDIM; i++) {
  //  gidx = gid % dims_$${i};
  //  gid /= dims_$${i};
  //  dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
  //  dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
  //  src = ptr_add(src, gidx*src_strides_$${i});
  //}
  // advance src to this thread's chunk of inp_per_thread elements
  src = ptr_add(src, idx*inp_per_thread*src_strides_0);

  LOCAL_MEM radix_t inv_bits;
  if (idx==0) {
    known_bits = known_bits_mask = 0;
    // FIX: abs() on a ga_ssize may truncate through the int overload;
    // use an explicit conditional negate instead.
    k2 = (k < 0) ? -k : k;
    // k < 0 selects the smallest k: complement all radix keys
    inv_bits = (k>=0) ? 0 : (~0);
    write_base = 0;
  }
  if (k<0) { k = -k; }  // k is non-negative from here on
  local_barrier();

  // 1. find bits of top-k-th value using radix select
  #pragma unroll
  for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
  /*for (i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i*=-1) {*/
    if (lane_id == 0) {
      #pragma unroll
      for (int bin=0; bin<RADIX_SIZE; ++bin) {
        smem[bin + warp_id*RADIX_SIZE] = 0;
      }
    }
    local_barrier();
    for (int j=0; j<inp_per_thread; ++j) {
      in_range = (idx*inp_per_thread+j) < size;
      xval = in_range ? ptr_read(src, j*src_strides_0) : (INPUT_TYPE)0;
      x = inv_bits^RadixConfig<INPUT_TYPE>::convert(xval);
      ga_int digit = (int)((x>>i) & (RADIX_SIZE-1));
      // count within warp, restricted to values still matching the
      // already-resolved high bits of the top-k-th value
      #pragma unroll
      for (int bin=0; bin<RADIX_SIZE; ++bin) {
        bool incr_bin = (
          (bin == digit) &&
          ((x&known_bits_mask) == known_bits) &&
          in_range);
        ga_uint incr_bin_warp = __ballot(incr_bin);
        if (lane_id==0)
          smem[bin + RADIX_SIZE*warp_id] += __popc(incr_bin_warp);
      }
    }
    local_barrier();
    // sum counts across all warps
    // TODO: test in-block parallel sum?
    if (idx < RADIX_SIZE) {
      for(int w=RADIX_SIZE;
          w<(LDIM_0/ GA_WARP_SIZE)*RADIX_SIZE;
          w+=RADIX_SIZE)
        smem[idx] += smem[idx + w];
    }
    local_barrier();
    // update known bits: the top-k-th value lives in the highest bin
    // whose count reaches the remaining k2
    if (idx==0) {
      #pragma unroll
      for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
        if (smem[bin] >= k2) {
          // FIX: cast before shifting — (bin << i) is a 32-bit int shift
          // and overflows / is UB for i >= 31 (all 64-bit input types,
          // and the top digit of 32-bit types).
          known_bits |= ((radix_t)bin << i);
          known_bits_mask |= ((radix_t)(RADIX_SIZE-1) << i);
          break;
        } else
          k2 -= smem[bin];
      }
    }
    local_barrier();
  }
  /*
  if (idx < RADIX_SIZE) {
    ptr_at(dstv, idx*dstv_strides_0) = known_bits;
    ptr_at(dstv, idx*dstv_strides_0) = smem[idx];
  }
  return;
  */

  // 2. write values smaller than top-kth
  // (known_bits now holds the full radix key of the top-k-th value, so
  // "x > known_bits" selects everything strictly inside the top k)
  for (i=0; i<inp_per_thread; ++i) {
    in_range = (idx*inp_per_thread+i) < size;
    xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
    x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
    is_topk = (x > known_bits) && in_range;
    // inclusive cumsum -> 1-based output slot within this round
    out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
    if (is_topk) {
#if WRITE_VALUE == 1
      ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
      ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
    }
    local_barrier();
    // last thread's inclusive cumsum is the total written this round
    // FIX: use the gpuarray macro LDIM_0 rather than CUDA-only blockDim.x
    if (idx == LDIM_0 - 1)
      write_base += out_idx;
    local_barrier();
  }

  // 3. write values equal to top-kth, stopping once k outputs exist
  for (i=0; i<inp_per_thread; ++i) {
    in_range = (idx*inp_per_thread+i) < size;
    xval = in_range ? ptr_read(src, i*src_strides_0) : (INPUT_TYPE)0;
    x = inv_bits ^ RadixConfig<INPUT_TYPE>::convert(xval);
    is_topk = (x == known_bits) && in_range;
    out_idx = binary_cumsum<radix_t>(idx, warp_id, lane_id, smem, is_topk);
    // k is already non-negative here, no abs() needed
    is_topk = ((out_idx+write_base) <= k) && is_topk;
    if (is_topk) {
#if WRITE_VALUE == 1
      ptr_at(dstv, (out_idx+write_base-1) * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
      ptr_at(dsti, (out_idx+write_base-1) * dsti_strides_0) = (INDEX_TYPE)(idx*inp_per_thread + i);
#endif
    }
    local_barrier();
    if (idx == LDIM_0 - 1)
      write_base += out_idx;
    local_barrier();
    if(write_base >= k)
      break;
  }
}
Diff is collapsed.
...@@ -342,20 +342,21 @@ class TopKOp(theano.Op): ...@@ -342,20 +342,21 @@ class TopKOp(theano.Op):
# TODO c_code # TODO c_code
__props__ = ('axis', 'return_values', 'return_indices') __props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
def __init__(self, axis=-1, return_indices=False, return_values=True): def __init__(self, axis=-1, return_indices=False, return_values=True, idx_dtype='int64'):
assert isinstance(axis, int) assert isinstance(axis, int)
assert return_indices or return_values assert return_indices or return_values
self.axis = axis self.axis = axis
self.return_indices = return_indices self.return_indices = return_indices
self.return_values = return_values self.return_values = return_values
self.idx_dtype = idx_dtype
def __str__(self): def __str__(self):
return '%(op)s{axis=%(axis)d}' % dict( return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, axis=self.axis) op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k, idx_dtype='int64'): def make_node(self, inp, k):
# numpy always uses float64 as output dtype for arg*() routines # numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu # however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
...@@ -366,7 +367,7 @@ class TopKOp(theano.Op): ...@@ -366,7 +367,7 @@ class TopKOp(theano.Op):
outs.append(inp.type()) outs.append(inp.type())
if self.return_indices: if self.return_indices:
outs.append( outs.append(
theano.tensor.TensorType(dtype=idx_dtype, broadcastable=bcast)()) theano.tensor.TensorType(dtype=self.idx_dtype, broadcastable=bcast)())
return theano.Apply(self, [inp, k], outs) return theano.Apply(self, [inp, k], outs)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
...@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'): ...@@ -458,18 +459,18 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return TopKOp(axis=axis, return_indices=True, return_values=False)(x, k, idx_dtype=idx_dtype) return TopKOp(axis=axis, return_indices=True, return_values=False, idx_dtype=idx_dtype)(x, k)
def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'): def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
''' """
Returns the results of both topk() and argtopk() in one Op. Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details. See the respective documentation for details.
''' """
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return TopKOp(axis=axis, return_indices=True)(x, k, idx_dtype=idx_dtype) return TopKOp(axis=axis, return_indices=True, idx_dtype=idx_dtype)(x, k)
...@@ -21,14 +21,12 @@ _int_dtypes = ( ...@@ -21,14 +21,12 @@ _int_dtypes = (
'int8', 'int16', 'int32', 'int64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64') 'uint8', 'uint16', 'uint32', 'uint64')
def gen_unique_vector(size, dtype): def gen_unique_vector(size, dtype):
# generate a randomized vector with unique elements # generate a randomized vector with unique elements
retval = np.arange(size*3) + np.random.uniform(-1., 1.) retval = np.arange(size*3) + np.random.uniform(-1., 1.)
return (retval[np.random.permutation(size)] - size*1.5).astype(dtype) return (retval[np.random.permutation(size)] - size*1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase): class Test_sort(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -236,7 +234,6 @@ def test_argsort_grad(): ...@@ -236,7 +234,6 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX) data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data]) utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_TopK(unittest.TestCase): class Test_TopK(unittest.TestCase):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment