提交 dfce51f6 authored 作者: Adam Becker's avatar Adam Becker

Address most of Fred's review comments

上级 4de386df
...@@ -19,12 +19,7 @@ except ImportError as e: ...@@ -19,12 +19,7 @@ except ImportError as e:
# To make sure theano is importable # To make sure theano is importable
pass pass
# TODO sort / argsort # TODO GPU sort / argsort
# TODO add runtime opt, if k==1, use max/min reduce
# also if k is axis size, just copy input tensor
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed
class GpuTopKOp(GpuKernelBase, TopKOp): class GpuTopKOp(GpuKernelBase, TopKOp):
...@@ -141,8 +136,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -141,8 +136,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
inp_dtc = ga.dtype_to_typecode(node.inputs[0].dtype) inp_dtc = ga.dtype_to_typecode(node.inputs[0].dtype)
if not self.return_indices: if not self.return_indices:
yv, = outs yv, = outs
else: elif self.return_values:
if self.return_values:
yv, yi = outs yv, yi = outs
else: else:
yi, = outs yi, = outs
...@@ -285,7 +279,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -285,7 +279,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([TopKOp]) @op_lifter([TopKOp], cuda_only=True)
@register_opt2([TopKOp], 'fast_compile') @register_opt2([TopKOp], 'fast_compile')
def local_gpua_topkop(op, ctx_name, inputs, outputs): def local_gpua_topkop(op, ctx_name, inputs, outputs):
if isinstance(op, GpuTopKOp): if isinstance(op, GpuTopKOp):
......
...@@ -221,19 +221,19 @@ def argsort(a, axis=-1, kind='quicksort', order=None): ...@@ -221,19 +221,19 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
return ArgSortOp(kind, order)(a, axis) return ArgSortOp(kind, order)(a, axis)
if hasattr(np, 'argpartition'): def _topk_py_impl(op, x, k, axis, idx_dtype):
# numpy >= 1.8 implementation
def _topk_py_impl(op, x, k, axis, idx_dtype):
ndim = x.ndim ndim = x.ndim
if k == 0:
raise ValueError('topk: k cannot be zero')
if abs(k) == 1: if abs(k) == 1:
i = (k + 1) // 2 # negative k means min instead of max
fn_max = [np.min, np.max][i] fn_max = [None, np.max, np.min][k]
fn_argmax = [np.argmin, np.argmax][i] fn_argmax = [None, np.argmax, np.argmin][k]
if not op.return_indices: if not op.return_indices:
return np.expand_dims(fn_max(x, axis=axis), axis) return np.expand_dims(fn_max(x, axis=axis), axis)
elif op.return_values: elif op.return_values:
zi = np.expand_dims( zi = np.expand_dims(
fn_argmax(x, axis=axis).astype(idx_dtype), axis) fn_argmax(x, axis=axis), axis)
idx2 = tuple( idx2 = tuple(
np.arange(s).reshape( np.arange(s).reshape(
(s,) + (1,) * (ndim - i - 1) (s,) + (1,) * (ndim - i - 1)
...@@ -242,7 +242,7 @@ if hasattr(np, 'argpartition'): ...@@ -242,7 +242,7 @@ if hasattr(np, 'argpartition'):
return zv, zi.astype(idx_dtype) return zv, zi.astype(idx_dtype)
else: else:
zi = np.expand_dims( zi = np.expand_dims(
fn_argmax(x, axis=axis).astype(idx_dtype), axis) fn_argmax(x, axis=axis), axis)
return zi.astype(idx_dtype) return zi.astype(idx_dtype)
asize = x.shape[axis] asize = x.shape[axis]
...@@ -263,12 +263,8 @@ if hasattr(np, 'argpartition'): ...@@ -263,12 +263,8 @@ if hasattr(np, 'argpartition'):
return zi return zi
idx = [slice(None)] * ndim idx = [slice(None)] * ndim
if k > 0: idx[axis] = slice(-k, None) if k > 0 else idx[axis] = slice(-k)
idx[axis] = slice(-k, None)
elif k < 0:
idx[axis] = slice(-k)
else:
raise ValueError('k cannot be zero')
if not op.return_indices: if not op.return_indices:
zv = np.partition(x, -k, axis=axis)[idx] zv = np.partition(x, -k, axis=axis)[idx]
return zv return zv
...@@ -282,11 +278,7 @@ if hasattr(np, 'argpartition'): ...@@ -282,11 +278,7 @@ if hasattr(np, 'argpartition'):
return zv, zi.astype(idx_dtype) return zv, zi.astype(idx_dtype)
else: else:
zi = np.argpartition(x, -k, axis=axis)[idx] zi = np.argpartition(x, -k, axis=axis)[idx]
return zi return zi.astype(idx_dtype)
else:
def _topk_py_impl(op, x, k, axis, idx_dtype):
# TODO better compatibility?
raise NotImplementedError('TopKOp: need numpy.argpartition() method (numpy >= 1.8)')
class TopKOp(theano.Op): class TopKOp(theano.Op):
...@@ -296,6 +288,7 @@ class TopKOp(theano.Op): ...@@ -296,6 +288,7 @@ class TopKOp(theano.Op):
Parameters Parameters
---------- ----------
axis: integer axis: integer
Defaults to ``-1``.
The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where
``ndim`` is the dimensionality of input tensor. ``ndim`` is the dimensionality of input tensor.
...@@ -306,7 +299,13 @@ class TopKOp(theano.Op): ...@@ -306,7 +299,13 @@ class TopKOp(theano.Op):
Notes Notes
----- -----
- By default, this Op gives two outputs: values and indices. However, the optimizer may - By default, this Op gives two outputs: values and indices. However, the optimizer may
remove a certain output if not needed for computing graph outputs. remove a certain output if not needed.
- Computing gradient is only possible when both values and indices are computed in
forward pass.
- If the top-k-th value is not unique, we cannot guarantee that the output indices are
deterministically chosen.
See Also See Also
-------- --------
...@@ -326,11 +325,17 @@ class TopKOp(theano.Op): ...@@ -326,11 +325,17 @@ class TopKOp(theano.Op):
only_top_kth: bool only_top_kth: bool
Defaults to ``False`` Defaults to ``False``
If ``True``, will only find the exact top k-th element. The Op behaves If ``True``, will only find one exact top k-th element on given axis.
like a reduction.
''' '''
# TODO c_code # TODO c_code
# TODO add opt, if k==1, use max/min reduce
# also if k is axis size, just copy input tensor
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed
# TODO R_op
__props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype') __props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
...@@ -338,7 +343,12 @@ class TopKOp(theano.Op): ...@@ -338,7 +343,12 @@ class TopKOp(theano.Op):
self, self,
axis=-1, axis=-1,
idx_dtype='int64'): idx_dtype='int64'):
assert isinstance(axis, int) if not isinstance(axis, int):
raise TypeError(
'"axis" parameter must be integer, got "%s"' % type(self.axis))
if idx_dtype not in theano.tensor.integer_dtypes:
raise TypeError(
'"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype)
self.axis = axis self.axis = axis
self.return_indices = True self.return_indices = True
self.return_values = True self.return_values = True
...@@ -349,9 +359,17 @@ class TopKOp(theano.Op): ...@@ -349,9 +359,17 @@ class TopKOp(theano.Op):
op=self.__class__.__name__, axis=self.axis) op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k): def make_node(self, inp, k):
# numpy always uses float64 as output dtype for arg*() routines # numpy always uses int64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu # however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
ndim = inp.ndim
if ndim == 0:
raise ValueError('Cannot take scalar as input')
if not -ndim <= self.axis < ndim:
raise IndexError(
'"axis" parameter out of range,'
' expected integer within [%d, %d]' % (-ndim, ndim - 1))
k = theano.tensor.as_tensor_variable(k) k = theano.tensor.as_tensor_variable(k)
bcast = inp.type.broadcastable bcast = inp.type.broadcastable
outs = [] outs = []
...@@ -383,16 +401,7 @@ class TopKOp(theano.Op): ...@@ -383,16 +401,7 @@ class TopKOp(theano.Op):
def infer_shape(self, node, inp_shapes): def infer_shape(self, node, inp_shapes):
_check_tensor_is_scalar(node.inputs[1]) _check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0]) shp = list(inp_shapes[0])
if not isinstance(self.axis, int):
raise TypeError(
'"axis" parameter must be integer, got "%s"' % type(self.axis))
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
if ndim == 0:
raise ValueError('Cannot take 0d tensor as input')
if not -ndim <= self.axis < ndim:
raise IndexError(
'"axis" parameter out of range,'
' expected integer within [%d, %d]' % (-ndim, ndim - 1))
shp[self.axis] = np.abs(node.inputs[1]) shp[self.axis] = np.abs(node.inputs[1])
shp = tuple(shp) shp = tuple(shp)
return [shp for i in [self.return_values, self.return_indices] if i] return [shp for i in [self.return_values, self.return_indices] if i]
...@@ -437,8 +446,8 @@ def topk(x, k, axis=-1, idx_dtype='int64'): ...@@ -437,8 +446,8 @@ def topk(x, k, axis=-1, idx_dtype='int64'):
Must not be 0. If negative, gives k-smallest elements instead. Must not be 0. If negative, gives k-smallest elements instead.
axis: integer or ``None`` axis: integer or ``None``
The axis along which the operation is performed. If ``None``, The axis along which the operation is performed.
works on the flattened array. If ``None``, works on the flattened array.
idx_dtype: string idx_dtype: string
Specify output dtype used in indices, defaults to ``int64``, must be integer type. Specify output dtype used in indices, defaults to ``int64``, must be integer type.
...@@ -450,7 +459,7 @@ def topk(x, k, axis=-1, idx_dtype='int64'): ...@@ -450,7 +459,7 @@ def topk(x, k, axis=-1, idx_dtype='int64'):
Notes Notes
----- -----
- The returned values may not be sorted. - The returned values are not sorted.
""" """
if axis is None: if axis is None:
...@@ -471,9 +480,9 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'): ...@@ -471,9 +480,9 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
k: integer constant/variable k: integer constant/variable
Must not be 0. If negative, gives k-smallest elements instead. Must not be 0. If negative, gives k-smallest elements instead.
axis: integer or ``None`` axis: integer, tuple/list of integers, or ``None``
The axis along which the operation is performed. If ``None``, The axis along which the operation is performed.
works on the flattened array. If ``None``, works on the flattened array.
idx_dtype: string idx_dtype: string
Specify output dtype, defaults to ``int64``, must be integer type. Specify output dtype, defaults to ``int64``, must be integer type.
...@@ -484,7 +493,10 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'): ...@@ -484,7 +493,10 @@ def argtopk(x, k, axis=-1, idx_dtype='int64'):
Notes Notes
----- -----
- The corresponding values of returned indices may not be sorted. - The corresponding values of returned indices are not sorted.
- If the top-k-th value is not unique, we cannot guarantee the output
indices are deterministically chosen.
""" """
if axis is None: if axis is None:
...@@ -501,6 +513,10 @@ def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'): ...@@ -501,6 +513,10 @@ def topk_and_argtopk(x, k, axis=-1, idx_dtype='int64'):
See the respective documentation for details. See the respective documentation for details.
Returns
-------
tuple: (values, indices)
""" """
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
......
...@@ -13,13 +13,7 @@ from theano.tensor.sort import sort, SortOp ...@@ -13,13 +13,7 @@ from theano.tensor.sort import sort, SortOp
from theano.tensor.sort import argsort, ArgSortOp from theano.tensor.sort import argsort, ArgSortOp
from theano.tensor.sort import topk, argtopk, topk_and_argtopk, TopKOp from theano.tensor.sort import topk, argtopk, topk_and_argtopk, TopKOp
_dtypes = ( _all_dtypes = tensor.integer_dtypes + tensor.float_dtypes
'float32', 'float64',
'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64')
_int_dtypes = (
'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64')
def gen_unique_vector(size, dtype): def gen_unique_vector(size, dtype):
...@@ -39,7 +33,7 @@ class Test_sort(unittest.TestCase): ...@@ -39,7 +33,7 @@ class Test_sort(unittest.TestCase):
a = tensor.dmatrix() a = tensor.dmatrix()
w = sort(a) w = sort(a)
f = theano.function([a], w) f = theano.function([a], w)
assert np.allclose(f(self.m_val), np.sort(self.m_val)) assert utt.assert_allclose(f(self.m_val), np.sort(self.m_val))
def test2(self): def test2(self):
a = tensor.dmatrix() a = tensor.dmatrix()
...@@ -49,7 +43,7 @@ class Test_sort(unittest.TestCase): ...@@ -49,7 +43,7 @@ class Test_sort(unittest.TestCase):
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(self.m_val, axis_val) gv = f(self.m_val, axis_val)
gt = np.sort(self.m_val, axis_val) gt = np.sort(self.m_val, axis_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
def test3(self): def test3(self):
a = tensor.dvector() a = tensor.dvector()
...@@ -57,7 +51,7 @@ class Test_sort(unittest.TestCase): ...@@ -57,7 +51,7 @@ class Test_sort(unittest.TestCase):
f = theano.function([a], w2) f = theano.function([a], w2)
gv = f(self.v_val) gv = f(self.v_val)
gt = np.sort(self.v_val) gt = np.sort(self.v_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
def test4(self): def test4(self):
a = tensor.dmatrix() a = tensor.dmatrix()
...@@ -67,7 +61,7 @@ class Test_sort(unittest.TestCase): ...@@ -67,7 +61,7 @@ class Test_sort(unittest.TestCase):
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(self.m_val, axis_val) gv = f(self.m_val, axis_val)
gt = np.sort(self.m_val, axis_val) gt = np.sort(self.m_val, axis_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
def test5(self): def test5(self):
a1 = SortOp("mergesort", []) a1 = SortOp("mergesort", [])
...@@ -84,7 +78,7 @@ class Test_sort(unittest.TestCase): ...@@ -84,7 +78,7 @@ class Test_sort(unittest.TestCase):
f = theano.function([a], l) f = theano.function([a], l)
gv = f(self.m_val) gv = f(self.m_val)
gt = np.sort(self.m_val, None) gt = np.sort(self.m_val, None)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
def test_grad_vector(self): def test_grad_vector(self):
data = np.random.rand(10).astype(theano.config.floatX) data = np.random.rand(10).astype(theano.config.floatX)
...@@ -176,7 +170,7 @@ def test_argsort(): ...@@ -176,7 +170,7 @@ def test_argsort():
f = theano.function([a], w) f = theano.function([a], w)
gv = f(m_val) gv = f(m_val)
gt = np.argsort(m_val) gt = np.argsort(m_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
# Example 2 # Example 2
a = tensor.dmatrix() a = tensor.dmatrix()
...@@ -186,7 +180,7 @@ def test_argsort(): ...@@ -186,7 +180,7 @@ def test_argsort():
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(m_val, axis_val) gv = f(m_val, axis_val)
gt = np.argsort(m_val, axis_val) gt = np.argsort(m_val, axis_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
# Example 3 # Example 3
a = tensor.dvector() a = tensor.dvector()
...@@ -194,7 +188,7 @@ def test_argsort(): ...@@ -194,7 +188,7 @@ def test_argsort():
f = theano.function([a], w2) f = theano.function([a], w2)
gv = f(v_val) gv = f(v_val)
gt = np.argsort(v_val) gt = np.argsort(v_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
# Example 4 # Example 4
a = tensor.dmatrix() a = tensor.dmatrix()
...@@ -204,7 +198,7 @@ def test_argsort(): ...@@ -204,7 +198,7 @@ def test_argsort():
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(m_val, axis_val) gv = f(m_val, axis_val)
gt = np.argsort(m_val, axis_val) gt = np.argsort(m_val, axis_val)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
# Example 5 # Example 5
a = tensor.dmatrix() a = tensor.dmatrix()
...@@ -222,7 +216,7 @@ def test_argsort(): ...@@ -222,7 +216,7 @@ def test_argsort():
f = theano.function([a], w2) f = theano.function([a], w2)
gv = f(m_val) gv = f(m_val)
gt = np.argsort(m_val, None) gt = np.argsort(m_val, None)
assert np.allclose(gv, gt) assert utt.assert_allclose(gv, gt)
def test_argsort_grad(): def test_argsort_grad():
...@@ -243,7 +237,7 @@ class Test_TopK(unittest.TestCase): ...@@ -243,7 +237,7 @@ class Test_TopK(unittest.TestCase):
pass pass
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_dtypes, _int_dtypes, [-1, 0, None])) _all_dtypes, tensor.integer_dtypes, [-1, 0, None]))
def test_argtopk_sanity(self, dtype, idx_dtype, axis): def test_argtopk_sanity(self, dtype, idx_dtype, axis):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
fn = theano.function([x], argtopk(x, 1, axis=axis, idx_dtype=idx_dtype)) fn = theano.function([x], argtopk(x, 1, axis=axis, idx_dtype=idx_dtype))
...@@ -253,7 +247,7 @@ class Test_TopK(unittest.TestCase): ...@@ -253,7 +247,7 @@ class Test_TopK(unittest.TestCase):
assert yval.dtype == np.dtype(idx_dtype) assert yval.dtype == np.dtype(idx_dtype)
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_dtypes, [-1, 0, None])) _all_dtypes, [-1, 0, None]))
def test_topk_sanity(self, dtype, axis): def test_topk_sanity(self, dtype, axis):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
fn = theano.function([x], topk(x, 1, axis=axis)) fn = theano.function([x], topk(x, 1, axis=axis))
...@@ -263,7 +257,7 @@ class Test_TopK(unittest.TestCase): ...@@ -263,7 +257,7 @@ class Test_TopK(unittest.TestCase):
assert yval.dtype == xval.dtype assert yval.dtype == xval.dtype
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_dtypes, _int_dtypes, [-1, 0, None])) _all_dtypes, tensor.integer_dtypes, [-1, 0, None]))
def test_combined_sanity(self, dtype, idx_dtype, axis): def test_combined_sanity(self, dtype, idx_dtype, axis):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
yv, yi = topk_and_argtopk(x, 1, axis=axis, idx_dtype=idx_dtype) yv, yi = topk_and_argtopk(x, 1, axis=axis, idx_dtype=idx_dtype)
...@@ -271,14 +265,14 @@ class Test_TopK(unittest.TestCase): ...@@ -271,14 +265,14 @@ class Test_TopK(unittest.TestCase):
xval = np.asarray([1]).astype(dtype) xval = np.asarray([1]).astype(dtype)
yvval, yival = fn(xval) yvval, yival = fn(xval)
assert yival == np.asarray([0], dtype=idx_dtype) assert yival == np.asarray([0], dtype=idx_dtype)
assert np.allclose(xval, yvval) assert utt.assert_allclose(xval, yvval)
assert yvval.dtype == xval.dtype assert yvval.dtype == xval.dtype
assert yival.dtype == np.dtype(idx_dtype) assert yival.dtype == np.dtype(idx_dtype)
@utt.parameterized.expand(chain( @utt.parameterized.expand(chain(
product( product(
(16, 61, 257), (16, 61, 257),
(1, -1, 10, -10, 'n//2', 'n-1', '-n', '1-n'), (1, -1, -10, 'n//2', 'n-1', '-n', '1-n'),
('float64', 'float16', 'int16', 'int8')), ('float64', 'float16', 'int16', 'int8')),
((2049, 1337, 'float64'),))) ((2049, 1337, 'float64'),)))
def test_topk_1d(self, size, k, dtype): def test_topk_1d(self, size, k, dtype):
...@@ -297,7 +291,7 @@ class Test_TopK(unittest.TestCase): ...@@ -297,7 +291,7 @@ class Test_TopK(unittest.TestCase):
print(np.sort(yval)) print(np.sort(yval))
print(goal) print(goal)
assert yval.dtype == goal.dtype assert yval.dtype == goal.dtype
assert np.allclose(np.sort(yval), goal) assert utt.assert_allclose(np.sort(yval), goal)
@utt.parameterized.expand(chain( @utt.parameterized.expand(chain(
product( product(
...@@ -345,7 +339,7 @@ class Test_TopK(unittest.TestCase): ...@@ -345,7 +339,7 @@ class Test_TopK(unittest.TestCase):
# due to uniqueness, we expect indices same # due to uniqueness, we expect indices same
assert np.all(xval[np.sort(yival)] == xval[np.sort(goali)]) assert np.all(xval[np.sort(yival)] == xval[np.sort(goali)])
assert np.allclose(np.sort(yvval), goalv) assert utt.assert_allclose(np.sort(yvval), goalv)
@utt.parameterized.expand(chain( @utt.parameterized.expand(chain(
product( product(
...@@ -368,11 +362,11 @@ class Test_TopK(unittest.TestCase): ...@@ -368,11 +362,11 @@ class Test_TopK(unittest.TestCase):
goal = np.argsort(xval)[idx].astype('int32') goal = np.argsort(xval)[idx].astype('int32')
print(goal) print(goal)
print(np.argsort(xval)) print(np.argsort(xval))
assert np.allclose(np.sort(xval[yval]), np.sort(xval[goal])) assert utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((1, 1), (2, 3), (17, 15), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2017, 5, 3)), ((17, 15), (2, 3, 5, 7, 11), (2017, 5, 3)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'), (-1, '(1+n)//2', '-n', '1-n'),
('float32', 'int32'), ('float32', 'int32'),
('int32', 'int64'))) ('int32', 'int64')))
def test_argtopk_nd(self, shp, k_, dtype, idx_dtype): def test_argtopk_nd(self, shp, k_, dtype, idx_dtype):
...@@ -410,7 +404,7 @@ class Test_TopK(unittest.TestCase): ...@@ -410,7 +404,7 @@ class Test_TopK(unittest.TestCase):
assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis)) assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((3,), (257,), (2, 3), (17, 15), (11, 7, 5), (5, 3, 5, 3), (2, 3, 5, 7, 11)), ((257,), (17, 15), (5, 3, 5, 3), (2, 3, 5, 7, 11)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'))) (1, -1, '(1+n)//2', 'n-1', '-n', '1-n')))
def test_grad(self, shp, k_): def test_grad(self, shp, k_):
ndim = len(shp) ndim = len(shp)
...@@ -429,8 +423,8 @@ class Test_TopK(unittest.TestCase): ...@@ -429,8 +423,8 @@ class Test_TopK(unittest.TestCase):
class TopKInferShapeTester(utt.InferShapeTester): class TopKInferShapeTester(utt.InferShapeTester):
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)), ((15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'))) (1, '(1+n)//2', 'n-1', '-n')))
def test_topk_infer_shape(self, shp, k_): def test_topk_infer_shape(self, shp, k_):
ndim = len(shp) ndim = len(shp)
for axis in range(-ndim, ndim): for axis in range(-ndim, ndim):
...@@ -452,8 +446,8 @@ class TopKInferShapeTester(utt.InferShapeTester): ...@@ -452,8 +446,8 @@ class TopKInferShapeTester(utt.InferShapeTester):
[x], [y], [xval], TopKOp) [x], [y], [xval], TopKOp)
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)), ((15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'))) (-1, '(1+n)//2', '1-n')))
def test_argtopk_infer_shape(self, shp, k_): def test_argtopk_infer_shape(self, shp, k_):
ndim = len(shp) ndim = len(shp)
for axis in range(-ndim, ndim): for axis in range(-ndim, ndim):
...@@ -476,7 +470,7 @@ class TopKInferShapeTester(utt.InferShapeTester): ...@@ -476,7 +470,7 @@ class TopKInferShapeTester(utt.InferShapeTester):
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)), ((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'))) (1, '(1+n)//2', 'n-1', 'n')))
def test_combined_infer_shape(self, shp, k_): def test_combined_infer_shape(self, shp, k_):
ndim = len(shp) ndim = len(shp)
for axis in range(-ndim, ndim): for axis in range(-ndim, ndim):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论