Commit d95bf06c — authored by Adam Becker

small fixes for topk

Parent commit: d517cd9f
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
import theano import theano
from theano.tensor.basic import mul, arange from theano.tensor.basic import mul, arange
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.tensor.subtensor import set_subtensor
def _variable_is_none(var): def _variable_is_none(var):
...@@ -319,9 +320,9 @@ class TopKOp(theano.Op): ...@@ -319,9 +320,9 @@ class TopKOp(theano.Op):
# TODO more params # TODO more params
''' '''
sorted: bool sorted: bool
Defaults to ``False`` Defaults to ``True``
If True, the result array would be incremental-sorted. If True, the result array would be sorted in descending order.
only_top_kth: bool only_top_kth: bool
Defaults to ``False`` Defaults to ``False``
...@@ -335,7 +336,6 @@ class TopKOp(theano.Op): ...@@ -335,7 +336,6 @@ class TopKOp(theano.Op):
# also if k is axis size, just copy input tensor # also if k is axis size, just copy input tensor
# TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only # TODO add opt to merge argtopk / topk, or split topk_and_argtopk when only
# one result is needed # one result is needed
# TODO R_op
__props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype') __props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
...@@ -346,6 +346,8 @@ class TopKOp(theano.Op): ...@@ -346,6 +346,8 @@ class TopKOp(theano.Op):
return_values=True, return_values=True,
return_indices=True return_indices=True
): ):
# numpy always uses int64 as output dtype for arg*() routines
# however, we add "idx_dtype" param as memory is more precious on gpu
if not isinstance(axis, int): if not isinstance(axis, int):
raise TypeError( raise TypeError(
'"axis" parameter must be integer, got "%s"' % type(axis)) '"axis" parameter must be integer, got "%s"' % type(axis))
...@@ -366,8 +368,6 @@ class TopKOp(theano.Op): ...@@ -366,8 +368,6 @@ class TopKOp(theano.Op):
op=self.__class__.__name__, axis=self.axis) op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, kth): def make_node(self, inp, kth):
# numpy always uses int64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
ndim = inp.ndim ndim = inp.ndim
if ndim == 0: if ndim == 0:
...@@ -378,6 +378,7 @@ class TopKOp(theano.Op): ...@@ -378,6 +378,7 @@ class TopKOp(theano.Op):
' expected integer within [%d, %d]' % (-ndim, ndim - 1)) ' expected integer within [%d, %d]' % (-ndim, ndim - 1))
kth = theano.tensor.as_tensor_variable(kth) kth = theano.tensor.as_tensor_variable(kth)
_check_tensor_is_scalar(kth)
bcast = inp.type.broadcastable bcast = inp.type.broadcastable
outs = [] outs = []
if self.return_values: if self.return_values:
...@@ -403,7 +404,6 @@ class TopKOp(theano.Op): ...@@ -403,7 +404,6 @@ class TopKOp(theano.Op):
pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype) pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)
def infer_shape(self, node, inp_shapes): def infer_shape(self, node, inp_shapes):
_check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0]) shp = list(inp_shapes[0])
shp[self.axis] = np.abs(node.inputs[1]) shp[self.axis] = np.abs(node.inputs[1])
shp = tuple(shp) shp = tuple(shp)
...@@ -412,16 +412,11 @@ class TopKOp(theano.Op): ...@@ -412,16 +412,11 @@ class TopKOp(theano.Op):
def L_op(self, inputs, outputs, out_grads): def L_op(self, inputs, outputs, out_grads):
x, k = inputs x, k = inputs
k_grad = grad_undefined(self, 1, k, 'topk: k is not differentiable') k_grad = grad_undefined(self, 1, k, 'topk: k is not differentiable')
if not (self.return_indices, self.return_values):
if not (self.return_indices and self.return_values):
x_grad = grad_undefined( x_grad = grad_undefined(
self, 0, x, 'topk: cannot get gradient' self, 0, x, 'topk: cannot get gradient'
' without both indices and values') ' without both indices and values')
elif x.ndim == 1:
z_grad = out_grads[0]
indices = outputs[-1]
x_grad = x.zeros_like(dtype=z_grad.dtype)
x_grad = theano.tensor.advanced_set_subtensor1(
x_grad, z_grad, indices)
else: else:
x_shp = theano.tensor.shape(x) x_shp = theano.tensor.shape(x)
z_grad = out_grads[0] z_grad = out_grads[0]
...@@ -431,12 +426,12 @@ class TopKOp(theano.Op): ...@@ -431,12 +426,12 @@ class TopKOp(theano.Op):
arange(x_shp[i]).dimshuffle([0] + ['x'] * (ndim - i - 1)) arange(x_shp[i]).dimshuffle([0] + ['x'] * (ndim - i - 1))
if i != axis else outputs[-1] for i in range(ndim)] if i != axis else outputs[-1] for i in range(ndim)]
x_grad = x.zeros_like(dtype=z_grad.dtype) x_grad = x.zeros_like(dtype=z_grad.dtype)
x_grad = theano.tensor.advanced_set_subtensor( x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)
x_grad, z_grad, *grad_indices)
return [x_grad, k_grad] return [x_grad, k_grad]
def topk(x, kth, axis=-1, idx_dtype='int64'): def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
""" """
Returns the k-largest elements along an axis. Returns the k-largest elements along an axis.
...@@ -452,6 +447,11 @@ def topk(x, kth, axis=-1, idx_dtype='int64'): ...@@ -452,6 +447,11 @@ def topk(x, kth, axis=-1, idx_dtype='int64'):
Upon which axis shall the operation be performed on. Upon which axis shall the operation be performed on.
If ``None``, works on flattened array. If ``None``, works on flattened array.
sorted: bool
Defaults to ``True``
If True, the result array would be sorted in descending order.
idx_dtype: string idx_dtype: string
Specify output dtype used in indices, defaults to ``int64``, must be integer type. Specify output dtype used in indices, defaults to ``int64``, must be integer type.
This option is here because indices are needed for gradient. This option is here because indices are needed for gradient.
...@@ -462,16 +462,18 @@ def topk(x, kth, axis=-1, idx_dtype='int64'): ...@@ -462,16 +462,18 @@ def topk(x, kth, axis=-1, idx_dtype='int64'):
Notes Notes
----- -----
- The returned values are not sorted. - ``sorted=True`` is not supported yet.
""" """
if sorted:
raise NotImplementedError("sorted=True is not supported yet.")
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return TopKOp(axis=axis, idx_dtype=idx_dtype)(x, kth)[0] return TopKOp(axis=axis, idx_dtype=idx_dtype)(x, kth)[0]
def argtopk(x, kth, axis=-1, idx_dtype='int64'): def argtopk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
""" """
Returns the indices of k-largest elements along an axis. Returns the indices of k-largest elements along an axis.
...@@ -483,6 +485,12 @@ def argtopk(x, kth, axis=-1, idx_dtype='int64'): ...@@ -483,6 +485,12 @@ def argtopk(x, kth, axis=-1, idx_dtype='int64'):
kth: integer constant/variable kth: integer constant/variable
Must not be 0. If negative, gives k-smallest elements instead. Must not be 0. If negative, gives k-smallest elements instead.
sorted: bool
Defaults to ``True``
If True, the result array of corresponding indices would be sorted in descending order.
axis: integer, tuple/list of integers, or ``None`` axis: integer, tuple/list of integers, or ``None``
Upon which axis shall the operation be performed on. Upon which axis shall the operation be performed on.
If ``None``, works on flattened array. If ``None``, works on flattened array.
...@@ -496,21 +504,23 @@ def argtopk(x, kth, axis=-1, idx_dtype='int64'): ...@@ -496,21 +504,23 @@ def argtopk(x, kth, axis=-1, idx_dtype='int64'):
Notes Notes
----- -----
- The corresponding values of returned indices are not sorted. - ``sorted=True`` is not supported yet.
- If the top-k-th value is not unique, we cannot guarantee the output - If the top-k-th value is not unique, we cannot guarantee the output
indices are deterministically chosen. indices are deterministically chosen.
""" """
if sorted:
raise NotImplementedError("sorted=True is not supported yet.")
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = 0
return TopKOp( return TopKOp(
axis=axis, axis=axis,
idx_dtype=idx_dtype)(x, kth)[1] idx_dtype=idx_dtype)(x, kth)[1]
def topk_and_argtopk(x, kth, axis=-1, idx_dtype='int64'): def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
""" """
Returns the results of both topk() and argtopk() in one Op. Returns the results of both topk() and argtopk() in one Op.
...@@ -521,9 +531,11 @@ def topk_and_argtopk(x, kth, axis=-1, idx_dtype='int64'): ...@@ -521,9 +531,11 @@ def topk_and_argtopk(x, kth, axis=-1, idx_dtype='int64'):
tuple: (values, indices) tuple: (values, indices)
""" """
if sorted:
raise NotImplementedError("sorted=True is not supported yet.")
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = 0
return TopKOp( return TopKOp(
axis=axis, axis=axis,
idx_dtype=idx_dtype)(x, kth) idx_dtype=idx_dtype)(x, kth)
...@@ -237,30 +237,30 @@ class Test_TopK(unittest.TestCase): ...@@ -237,30 +237,30 @@ class Test_TopK(unittest.TestCase):
pass pass
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_all_dtypes, tensor.integer_dtypes, [-1, 0, None])) _all_dtypes, tensor.integer_dtypes, [-1, 0, None], [False]))
def test_argtopk_sanity(self, dtype, idx_dtype, axis): def test_argtopk_sanity(self, dtype, idx_dtype, axis, sorted):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
fn = theano.function([x], argtopk(x, 1, axis=axis, idx_dtype=idx_dtype)) fn = theano.function([x], argtopk(x, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype))
xval = np.asarray([1]).astype(dtype) xval = np.asarray([1]).astype(dtype)
yval = fn(xval) yval = fn(xval)
assert yval == np.asarray([0], dtype=idx_dtype) assert yval == np.asarray([0], dtype=idx_dtype)
assert yval.dtype == np.dtype(idx_dtype) assert yval.dtype == np.dtype(idx_dtype)
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_all_dtypes, [-1, 0, None])) _all_dtypes, [-1, 0, None], [False]))
def test_topk_sanity(self, dtype, axis): def test_topk_sanity(self, dtype, axis, sorted):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
fn = theano.function([x], topk(x, 1, axis=axis)) fn = theano.function([x], topk(x, 1, axis=axis, sorted=sorted))
xval = np.asarray([1]).astype(dtype) xval = np.asarray([1]).astype(dtype)
yval = fn(xval) yval = fn(xval)
assert yval == xval assert yval == xval
assert yval.dtype == xval.dtype assert yval.dtype == xval.dtype
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
_all_dtypes, tensor.integer_dtypes, [-1, 0, None])) _all_dtypes, tensor.integer_dtypes, [-1, 0, None], [False]))
def test_combined_sanity(self, dtype, idx_dtype, axis): def test_combined_sanity(self, dtype, idx_dtype, axis, sorted):
x = tensor.vector(name='x', dtype=dtype) x = tensor.vector(name='x', dtype=dtype)
yv, yi = topk_and_argtopk(x, 1, axis=axis, idx_dtype=idx_dtype) yv, yi = topk_and_argtopk(x, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
fn = theano.function([x], [yv, yi]) fn = theano.function([x], [yv, yi])
xval = np.asarray([1]).astype(dtype) xval = np.asarray([1]).astype(dtype)
yvval, yival = fn(xval) yvval, yival = fn(xval)
...@@ -273,14 +273,15 @@ class Test_TopK(unittest.TestCase): ...@@ -273,14 +273,15 @@ class Test_TopK(unittest.TestCase):
product( product(
(16, 61, 257), (16, 61, 257),
(1, -1, -10, 'n//2', 'n-1', '-n', '1-n'), (1, -1, -10, 'n//2', 'n-1', '-n', '1-n'),
('float64', 'float16', 'int16', 'int8')), ('float64', 'float16', 'int16', 'int8'),
((2049, 1337, 'float64'),))) (False,)),
def test_topk_1d(self, size, k, dtype): ((2049, 1337, 'float64', False),)))
def test_topk_1d(self, size, k, dtype, sorted):
if isinstance(k, str): if isinstance(k, str):
k = eval(k.replace('n', str(size))) k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
y = topk(x, k) y = topk(x, k, sorted=sorted)
fn = theano.function([x], y) fn = theano.function([x], y)
# generate a all-unique array # generate a all-unique array
xval = gen_unique_vector(size, dtype) xval = gen_unique_vector(size, dtype)
...@@ -296,14 +297,15 @@ class Test_TopK(unittest.TestCase): ...@@ -296,14 +297,15 @@ class Test_TopK(unittest.TestCase):
(16, 61, 257), (16, 61, 257),
(1, -1, -10, 'n//2', 'n-1', '-n'), (1, -1, -10, 'n//2', 'n-1', '-n'),
('float32', 'int32'), ('float32', 'int32'),
(False,),
('int32', 'int64')), ('int32', 'int64')),
((2049, 1337, 'float32', 'int32'),))) ((2049, 1337, 'float32', False, 'int32'),)))
def test_argtopk_1d(self, size, k, dtype, idx_dtype): def test_argtopk_1d(self, size, k, dtype, sorted, idx_dtype):
if isinstance(k, str): if isinstance(k, str):
k = eval(k.replace('n', str(size))) k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
y = argtopk(x, k, idx_dtype=idx_dtype) y = argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype)
fn = theano.function([x], y) fn = theano.function([x], y)
# generate a all-unique array # generate a all-unique array
xval = gen_unique_vector(size, dtype) xval = gen_unique_vector(size, dtype)
...@@ -319,14 +321,15 @@ class Test_TopK(unittest.TestCase): ...@@ -319,14 +321,15 @@ class Test_TopK(unittest.TestCase):
(16, 61, 257), (16, 61, 257),
(1, -1, 10, 'n//2', 'n-1', '1-n'), (1, -1, 10, 'n//2', 'n-1', '1-n'),
('float32', 'int32'), ('float32', 'int32'),
(False,),
('int32', 'int64')), ('int32', 'int64')),
((2049, 1337, 'float32', 'int32'),))) ((2049, 1337, 'float32', False, 'int32'),)))
def test_combined_1d(self, size, k, dtype, idx_dtype): def test_combined_1d(self, size, k, dtype, sorted, idx_dtype):
if isinstance(k, str): if isinstance(k, str):
k = eval(k.replace('n', str(size))) k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
yv, yi = topk_and_argtopk(x, k, idx_dtype=idx_dtype) yv, yi = topk_and_argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype)
fn = theano.function([x], [yv, yi]) fn = theano.function([x], [yv, yi])
# generate a all-unique array # generate a all-unique array
xval = gen_unique_vector(size, dtype) xval = gen_unique_vector(size, dtype)
...@@ -343,15 +346,16 @@ class Test_TopK(unittest.TestCase): ...@@ -343,15 +346,16 @@ class Test_TopK(unittest.TestCase):
product( product(
(18, 62, 258), (18, 62, 258),
(1, -1, 'n//2'), (1, -1, 'n//2'),
('int32', 'float32')), ('int32', 'float32'),
((2048, 1337, 'float32'),))) (False,)),
def test_argtopk_1d_collision(self, size, k, dtype): ((2048, 1337, 'float32', False),)))
def test_argtopk_1d_collision(self, size, k, dtype, sorted):
# with non-unique kth max value # with non-unique kth max value
if isinstance(k, str): if isinstance(k, str):
k = eval(k.replace('n', str(size))) k = eval(k.replace('n', str(size)))
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
y = argtopk(x, k, idx_dtype='int32') y = argtopk(x, k, sorted=sorted, idx_dtype='int32')
fn = theano.function([x], y) fn = theano.function([x], y)
xval = np.repeat(np.random.uniform(-100., 100., size=size // 2).astype(dtype), 2) xval = np.repeat(np.random.uniform(-100., 100., size=size // 2).astype(dtype), 2)
xval = xval[np.random.permutation(size)] xval = xval[np.random.permutation(size)]
...@@ -364,8 +368,9 @@ class Test_TopK(unittest.TestCase): ...@@ -364,8 +368,9 @@ class Test_TopK(unittest.TestCase):
((17, 15), (2, 3, 5, 7, 11), (2017, 5, 3)), ((17, 15), (2, 3, 5, 7, 11), (2017, 5, 3)),
(-1, '(1+n)//2', '-n', '1-n'), (-1, '(1+n)//2', '-n', '1-n'),
('float32', 'int32'), ('float32', 'int32'),
(False,),
('int32', 'int64'))) ('int32', 'int64')))
def test_argtopk_nd(self, shp, k_, dtype, idx_dtype): def test_argtopk_nd(self, shp, k_, dtype, sorted, idx_dtype):
ndim = len(shp) ndim = len(shp)
for axis in range(-ndim, ndim): for axis in range(-ndim, ndim):
if isinstance(k_, str): if isinstance(k_, str):
...@@ -378,7 +383,7 @@ class Test_TopK(unittest.TestCase): ...@@ -378,7 +383,7 @@ class Test_TopK(unittest.TestCase):
x = theano.tensor.tensor( x = theano.tensor.tensor(
name='x', broadcastable=(False,) * len(shp), dtype=dtype) name='x', broadcastable=(False,) * len(shp), dtype=dtype)
y = argtopk(x, k, axis=axis, idx_dtype=idx_dtype) y = argtopk(x, k, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
fn = theano.function([x], y) fn = theano.function([x], y)
size = reduce(int.__mul__, shp) size = reduce(int.__mul__, shp)
xval = gen_unique_vector(size, dtype).reshape(shp) xval = gen_unique_vector(size, dtype).reshape(shp)
...@@ -393,8 +398,8 @@ class Test_TopK(unittest.TestCase): ...@@ -393,8 +398,8 @@ class Test_TopK(unittest.TestCase):
@utt.parameterized.expand(product( @utt.parameterized.expand(product(
((257,), (17, 15), (5, 3, 5, 3), (2, 3, 5, 7, 11)), ((257,), (17, 15), (5, 3, 5, 3), (2, 3, 5, 7, 11)),
(1, -1, '(1+n)//2', 'n-1', '-n', '1-n'))) (1, -1, '(1+n)//2', 'n-1', '-n', '1-n'), (False,)))
def test_grad(self, shp, k_): def test_grad(self, shp, k_, sorted):
ndim = len(shp) ndim = len(shp)
for axis in range(-ndim, ndim): for axis in range(-ndim, ndim):
if isinstance(k_, str): if isinstance(k_, str):
...@@ -410,7 +415,7 @@ class Test_TopK(unittest.TestCase): ...@@ -410,7 +415,7 @@ class Test_TopK(unittest.TestCase):
reduce(int.__mul__, shp), reduce(int.__mul__, shp),
dtype=theano.config.floatX dtype=theano.config.floatX
).reshape(shp) ).reshape(shp)
utt.verify_grad(lambda x: topk(x, k, axis=axis), [xval], eps=1e-2) utt.verify_grad(lambda x: topk(x, k, axis=axis, sorted=sorted), [xval], eps=1e-2)
class TopKInferShapeTester(utt.InferShapeTester): class TopKInferShapeTester(utt.InferShapeTester):
...@@ -431,7 +436,7 @@ class TopKInferShapeTester(utt.InferShapeTester): ...@@ -431,7 +436,7 @@ class TopKInferShapeTester(utt.InferShapeTester):
x = theano.tensor.tensor( x = theano.tensor.tensor(
name='x', broadcastable=(False,) * len(shp), name='x', broadcastable=(False,) * len(shp),
dtype=theano.config.floatX) dtype=theano.config.floatX)
yv, yi = topk_and_argtopk(x, k, axis=axis, idx_dtype='int32') yv, yi = topk_and_argtopk(x, k, axis=axis, sorted=False, idx_dtype='int32')
size = reduce(int.__mul__, shp) size = reduce(int.__mul__, shp)
xval = gen_unique_vector(size, theano.config.floatX).reshape(shp) xval = gen_unique_vector(size, theano.config.floatX).reshape(shp)
self._compile_and_check( self._compile_and_check(
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment