提交 f7b0a7a4 authored 作者: Seyed Morteza Hosseini's avatar Seyed Morteza Hosseini 提交者: Ricardo Vieira

Remove TopkOp

上级 ef22377d
......@@ -142,7 +142,7 @@ from pytensor.tensor.shape import (
# We import as `_shared` instead of `shared` to avoid confusion between
# `pytensor.shared` and `tensor._shared`.
from pytensor.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk
from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.subtensor import *
from pytensor.tensor.type import *
from pytensor.tensor.type_other import *
......
......@@ -68,7 +68,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter
......@@ -1224,35 +1223,4 @@ def local_merge_alloc(fgraph, node):
return [alloc(inputs_inner[0], *dims_outer)]
@register_useless("fast_compile")
@node_rewriter([TopKOp])
def local_useless_topk(fgraph, node):
    """Remove unused `TopKOp` outputs.

    Rewrites a two-output ``TopKOp`` (values + indices) into a
    single-output one when exactly one of its outputs actually has
    clients in the graph.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten.
    node : Apply
        Candidate ``TopKOp`` apply node.

    Returns
    -------
    dict or bool or None
        ``{old_output: new_output}`` on success; ``False``/``None`` when
        the rewrite does not apply.
    """
    op = node.op
    if not isinstance(op, TopKOp):
        return
    if not (op.return_values and op.return_indices):
        # Already a single-output op; nothing left to strip.
        return False
    x, k = node.inputs
    # Which outputs are actually consumed downstream?
    ret_val = bool(fgraph.clients[node.outputs[0]])
    ret_idx = bool(fgraph.clients[node.outputs[1]])
    if not (ret_val ^ ret_idx):
        # both true -> nothing to remove
        # both false -> let pruner handle
        return False
    # ``ret_idx`` doubles as an index: outputs[0] is values, outputs[1]
    # is indices, so this selects the single output that is used.
    old_output = node.outputs[ret_idx]
    new_output = TopKOp(
        axis=op.axis,
        sorted=op.sorted,
        idx_dtype=op.idx_dtype,
        return_values=ret_val,
        return_indices=ret_idx,
    )(x, k)
    # NOTE(review): the stack trace is always copied from outputs[0],
    # even when the surviving output is outputs[1]; copying from
    # `old_output` would seem more precise — confirm intent.
    copy_stack_trace(node.outputs[0], new_output)
    return {old_output: new_output}
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
......@@ -4,11 +4,9 @@ from pytensor.gradient import grad_undefined
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.op import Op
from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch
from pytensor.tensor.basic import arange, as_tensor_variable, switch
from pytensor.tensor.math import eq, ge, mul
from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import set_subtensor
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type import TensorType
def _variable_is_none(var):
......@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
else:
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
return zi.astype(idx_dtype)
class TopKOp(Op):
    """Operations related to finding k-largest elements.

    Parameters
    ----------
    axis: integer
        Defaults to ``-1``.
        The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where
        ``ndim`` is the dimensionality of input tensor.
    idx_dtype: string
        Specify output dtype for indices, defaults to ``int64``, must be integer type.
    sorted: bool
        NOTE: NOT IMPLEMENTED YET
        Defaults to ``True``
        If True, the result array would be sorted in descending order.

    Notes
    -----
    - The output order is not guaranteed. On the CPU, we use
      ``np.partition`` and ``np.argpartition`` that only make sure the
      k-th element is the correct one and that the other
      elements are on the correct side.
    - By default, this Op gives two outputs: values and indices. However
      optimizers may remove a certain output if not needed.
    - Computing the gradient requires the computation of the indices in
      the forward pass.
    - If the top-k-th value is not unique, we cannot guarantee the
      output indices being deterministically chosen.

    See Also
    --------
    topk
    argtopk
    argtopk_and_topk

    """

    # TODO more params
    """
    only_top_kth: bool
        Defaults to ``False``
        If ``True``, will only find one exact top k-th element on given axis.
    """

    # TODO c_code
    # TODO add opt, if k==1, use max/min reduce
    #      also if k is axis size, just copy input tensor
    # TODO add opt, to merge argtopk / topk
    __props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype")

    def __init__(
        self,
        axis=-1,
        sorted=True,
        idx_dtype="int64",
        return_values=True,
        return_indices=True,
    ):
        """Validate and store the op's (hashable) configuration.

        Raises
        ------
        TypeError
            If `axis` is not an int, or `idx_dtype` is not an integer dtype.
        NotImplementedError
            If ``sorted=True`` (not supported yet; note this is the default,
            so callers must pass ``sorted=False`` explicitly).
        ValueError
            If both `return_values` and `return_indices` are False.
        """
        # numpy always uses int64 as output dtype for arg*() routines
        # however, we add "idx_dtype" param as memory is more precious on gpu
        if not isinstance(axis, int):
            raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"')
        if sorted:
            raise NotImplementedError(
                "The sorted parameter is not yet implemented. Use sorted=False for now."
            )
        if idx_dtype not in integer_dtypes:
            raise TypeError(
                f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"'
            )
        if not (return_indices or return_values):
            raise ValueError(
                "Neither return_values nor return_indices is True, this isn't allowed"
            )
        self.axis = axis
        self.sorted = sorted
        self.return_values = return_values
        self.return_indices = return_indices
        self.idx_dtype = idx_dtype

    def __str__(self):
        # e.g. "TopKOp{axis=-1, sorted=False}"
        return "%(op)s{axis=%(axis)d, sorted=%(sorted)s}" % dict(
            op=self.__class__.__name__, axis=self.axis, sorted=self.sorted
        )

    def make_node(self, inp, kth):
        """Build the Apply node; outputs have fully unknown static shapes.

        Emits a values output and/or an indices output depending on
        ``return_values`` / ``return_indices``.
        """
        inp = as_tensor_variable(inp)
        ndim = inp.ndim
        if ndim == 0:
            raise ValueError("Cannot take scalar as input")
        if not -ndim <= self.axis < ndim:
            raise IndexError(
                '"axis" parameter out of range,'
                f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]"
            )
        kth = as_tensor_variable(kth)
        _check_tensor_is_scalar(kth)
        outs = []
        if self.return_values:
            outs.append(
                TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
            )
        if self.return_indices:
            outs.append(
                TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
            )
        return Apply(self, [inp, kth], outs)

    def perform(self, node, inputs, output_storage):
        """Compute top-k via the NumPy-based `_topk_py_impl` helper.

        The branch structure mirrors `make_node`: output_storage slots are
        laid out as [values, indices], with a slot omitted when the
        corresponding return_* flag is off.
        """
        x, k = inputs
        axis = self.axis
        if not self.return_indices:
            # values only
            pzv = output_storage[0]
            pzv[0] = _topk_py_impl(self, x, k, axis, None)
        elif self.return_values:
            # both values and indices
            pzv = output_storage[0]
            pzi = output_storage[1]
            pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
        else:
            # indices only
            pzi = output_storage[0]
            pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)

    def infer_shape(self, fgraph, node, inp_shapes):
        """Output shape equals the input shape with |k| along `self.axis`.

        abs() because a negative k selects the k smallest elements but
        still yields |k| of them.
        """
        shp = list(inp_shapes[0])
        shp[self.axis] = np.abs(node.inputs[1])
        shp = tuple(shp)
        # One shape per enabled output, in [values, indices] order.
        return [shp for i in [self.return_values, self.return_indices] if i]

    def L_op(self, inputs, outputs, out_grads):
        """Gradient of the values output w.r.t. `x`.

        Scatters the incoming gradient back to the selected positions
        using the indices output; `k` is not differentiable.
        """
        x, k = inputs
        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")
        if not (self.return_indices or self.return_values):
            x_grad = grad_undefined(
                self,
                0,
                x,
                "topk: cannot get gradient without both indices and values",
            )
        else:
            x_shp = shape(x)
            z_grad = out_grads[0]
            ndim = x.ndim
            axis = self.axis % ndim
            # Build a full fancy-index: aranges broadcast on every axis
            # except `axis`, where the top-k indices output is used.
            grad_indices = [
                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
                if i != axis
                else outputs[-1]
                for i in range(ndim)
            ]
            x_grad = x.zeros_like(dtype=z_grad.dtype)
            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)
        return [x_grad, k_grad]
def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Return the k-largest elements of `x` along `axis`.

    Parameters
    ----------
    x: tensor instance
    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.
    axis: integer or ``None``
        Axis along which to operate; ``None`` works on the flattened array.
    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``
        If True, the result array would be sorted in descending order.
    idx_dtype: string
        Integer dtype used for the (internally computed) indices, which
        are needed for the gradient. Defaults to ``int64``.

    Returns
    -------
    Tensor variable with same dtype as `x`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.
    """
    if axis is None:
        x, axis = flatten(x), 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    # Output 0 of the op is the values tensor.
    return op(x, kth)[0]
def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Return the indices of the k-largest elements of `x` along `axis`.

    Parameters
    ----------
    x: tensor instance
    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.
    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``
        If True, the result array of corresponding indices would be sorted in descending order.
    axis: integer, tuple/list of integers, or ``None``
        Axis along which to operate; ``None`` works on the flattened array.
    idx_dtype: string
        Output dtype, defaults to ``int64``, must be integer type.

    Returns
    -------
    Tensor variable with dtype specified in `idx_dtype`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.
    - If the top-k-th value is not unique, we cannot guarantee the output
      indices are deterministically chosen.
    """
    if axis is None:
        x, axis = flatten(x), 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    # Output 1 of the op is the indices tensor.
    return op(x, kth)[1]
def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Compute both topk() and argtopk() with a single Op.

    See the respective documentation for parameter details.

    Returns
    -------
    tuple: (values, indices)
    """
    if axis is None:
        x, axis = flatten(x), 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    return op(x, kth)
from functools import reduce
from itertools import chain, product
import numpy as np
import pytest
import pytensor
from pytensor.compile.mode import Mode
from pytensor.tensor.sort import (
ArgSortOp,
SortOp,
TopKOp,
argsort,
argtopk,
sort,
topk,
topk_and_argtopk,
)
from pytensor.tensor.sort import ArgSortOp, SortOp, argsort, sort
from pytensor.tensor.type import (
dmatrix,
dvector,
......@@ -24,8 +10,6 @@ from pytensor.tensor.type import (
lscalar,
matrix,
scalar,
tensor,
vector,
)
from tests import unittest_tools as utt
......@@ -253,272 +237,3 @@ def test_argsort_grad():
data = rng.random((2, 3, 3)).astype(pytensor.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
class TestTopK:
    # Overridable hooks so subclasses can re-run the suite under a
    # different compilation mode or with a different backend op class.
    mode = None
    op_class = TopKOp

    def setup_method(self):
        pass

    @pytest.mark.parametrize("dtype", _all_dtypes)
    @pytest.mark.parametrize("idx_dtype", integer_dtypes)
    @pytest.mark.parametrize("axis", [-1, 0, None])
    @pytest.mark.parametrize("sorted", [False])
    def test_argtopk_sanity(self, dtype, idx_dtype, axis, sorted):
        """Smoke test: argtopk of a length-1 vector is index 0."""
        x = vector(name="x", dtype=dtype)
        fn = pytensor.function(
            [x],
            argtopk(x, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype),
            mode=self.mode,
        )
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        xval = np.asarray([1]).astype(dtype)
        yval = fn(xval)
        assert yval == np.asarray([0], dtype=idx_dtype)
        assert yval.dtype == np.dtype(idx_dtype)

    @pytest.mark.parametrize("dtype", _all_dtypes)
    @pytest.mark.parametrize("axis", [-1, 0, None])
    @pytest.mark.parametrize("sorted", [False])
    def test_topk_sanity(self, dtype, axis, sorted):
        """Smoke test: topk of a length-1 vector is the vector itself."""
        x = vector(name="x", dtype=dtype)
        fn = pytensor.function(
            [x], topk(x, 1, axis=axis, sorted=sorted), mode=self.mode
        )
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        xval = np.asarray([1]).astype(dtype)
        yval = fn(xval)
        assert yval == xval
        assert yval.dtype == xval.dtype

    @pytest.mark.parametrize("dtype", _all_dtypes)
    @pytest.mark.parametrize("idx_dtype", integer_dtypes)
    @pytest.mark.parametrize("axis", [-1, 0, None])
    @pytest.mark.parametrize("sorted", [False])
    def test_combined_sanity(self, dtype, idx_dtype, axis, sorted):
        """Smoke test the combined values+indices output on a length-1 vector."""
        x = vector(name="x", dtype=dtype)
        yv, yi = topk_and_argtopk(x, 1, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
        fn = pytensor.function([x], [yv, yi], mode=self.mode)
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        xval = np.asarray([1]).astype(dtype)
        yvval, yival = fn(xval)
        assert yival == np.asarray([0], dtype=idx_dtype)
        utt.assert_allclose(xval, yvval)
        assert yvval.dtype == xval.dtype
        assert yival.dtype == np.dtype(idx_dtype)

    @pytest.mark.parametrize(
        "size, k, dtype, sorted",
        chain(
            product(
                (16, 61, 257),
                (1, -1, -10, "n//2", "n-1", "-n", "1-n"),
                ("float64", "float16", "int16", "int8"),
                (False,),
            ),
            ((2049, 1337, "float64", False),),
        ),
    )
    def test_topk_1d(self, size, k, dtype, sorted):
        """topk values on 1-D input match a NumPy sort reference."""
        if isinstance(k, str):
            # k is an expression in "n" (the vector size), e.g. "n//2".
            # eval is safe here: inputs come only from the parametrize list.
            k = eval(k.replace("n", str(size)))
        x = vector(name="x", dtype=dtype)
        y = topk(x, k, sorted=sorted)
        fn = pytensor.function([x], y, mode=self.mode)
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        # assert local_useless_topk opt is done properly
        assert 1 == len(fn.maker.fgraph.outputs[0].owner.outputs)
        # generate a all-unique array
        xval = gen_unique_vector(size, dtype)
        yval = fn(xval)
        # negative k selects the k smallest entries instead
        idx = slice(-k, None) if k > 0 else slice(-k)
        goal = np.sort(xval)[idx]
        assert yval.dtype == goal.dtype
        # output order is unspecified, so compare sorted
        utt.assert_allclose(goal, np.sort(yval))

    @pytest.mark.parametrize(
        "size, k, dtype, sorted, idx_dtype",
        chain(
            product(
                (16, 61, 257),
                (1, -1, -10, "n//2", "n-1", "-n"),
                ("float32", "int32"),
                (False,),
                ("int32", "int64"),
            ),
            ((2049, 1337, "float32", False, "int32"),),
        ),
    )
    def test_argtopk_1d(self, size, k, dtype, sorted, idx_dtype):
        """argtopk indices on 1-D input match a NumPy argsort reference."""
        if isinstance(k, str):
            # k is an expression in "n" (the vector size); safe eval (see above)
            k = eval(k.replace("n", str(size)))
        x = vector(name="x", dtype=dtype)
        y = argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype)
        fn = pytensor.function([x], y, mode=self.mode)
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        # assert local_useless_topk opt is done properly
        assert 1 == len(fn.maker.fgraph.outputs[0].owner.outputs)
        # generate a all-unique array
        xval = gen_unique_vector(size, dtype)
        yval = fn(xval)
        idx = slice(-k, None) if k > 0 else slice(-k)
        goal = np.argsort(xval)[idx].astype(idx_dtype)
        # due to uniqueness, we expect indices same
        assert np.all(xval[np.sort(yval)] == xval[np.sort(goal)])

    @pytest.mark.parametrize(
        "size, k, dtype, sorted, idx_dtype",
        chain(
            product(
                (16, 61, 257),
                (1, -1, 10, "n//2", "n-1", "1-n"),
                ("float32", "int32"),
                (False,),
                ("int32", "int64"),
            ),
            ((2049, 1337, "float32", False, "int32"),),
        ),
    )
    def test_combined_1d(self, size, k, dtype, sorted, idx_dtype):
        """Combined values+indices on 1-D input match NumPy references."""
        if isinstance(k, str):
            # k is an expression in "n" (the vector size); safe eval (see above)
            k = eval(k.replace("n", str(size)))
        x = vector(name="x", dtype=dtype)
        yv, yi = topk_and_argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype)
        fn = pytensor.function([x], [yv, yi], mode=self.mode)
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        # generate a all-unique array
        xval = gen_unique_vector(size, dtype)
        yvval, yival = fn(xval)
        idx = slice(-k, None) if k > 0 else slice(-k)
        goali = np.argsort(xval)[idx].astype(idx_dtype)
        goalv = xval[goali]
        # due to uniqueness, we expect indices same
        assert np.all(xval[np.sort(yival)] == xval[np.sort(goali)])
        utt.assert_allclose(np.sort(yvval), goalv)

    @pytest.mark.parametrize(
        "size, k, dtype, sorted",
        chain(
            product((18, 62, 258), (1, -1, "n//2"), ("int32", "float32"), (False,)),
            ((2048, 1337, "float32", False),),
        ),
    )
    def test_argtopk_1d_collision(self, size, k, dtype, sorted):
        # with non-unique kth max value
        if isinstance(k, str):
            # k is an expression in "n" (the vector size); safe eval (see above)
            k = eval(k.replace("n", str(size)))
        x = vector(name="x", dtype=dtype)
        y = argtopk(x, k, sorted=sorted, idx_dtype="int32")
        # DebugMode won't like the index change on collision on CPU
        # So don't use DebugMode here.
        mode = self.mode
        if isinstance(self.mode, pytensor.compile.debugmode.DebugMode):
            mode = Mode(optimizer=mode.optimizer)
        fn = pytensor.function([x], y, mode=mode)
        assert any(isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes)
        rng = np.random.default_rng(utt.fetch_seed())
        # every value appears exactly twice, in shuffled order
        xval = np.repeat(rng.uniform(-100.0, 100.0, size=size // 2).astype(dtype), 2)
        xval = xval[rng.permutation(size)]
        yval = fn(xval)
        idx = slice(-k, None) if k > 0 else slice(-k)
        goal = np.argsort(xval)[idx].astype("int32")
        # indices may differ on ties; compare the selected values instead
        utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))

    @pytest.mark.parametrize(
        "shp, k_, dtype, sorted, idx_dtype",
        product(
            (
                (17, 15),
                (2, 3, 5, 7, 11),
                (500, 5, 3),
            ),  # NB: Test may fail with bigger sizes (e.g. (2017, 5, 3)) due to "too many resources requested" kernel error on some GPUs.
            (-1, "(1+n)//2", "-n", "1-n"),
            ("float32", "int32"),
            (False,),
            ("int32", "int64"),
        ),
    )
    def test_argtopk_nd(self, shp, k_, dtype, sorted, idx_dtype):
        """argtopk on N-D input, exercised over every possible axis."""
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # k_ is an expression in "n" (the extent along `axis`)
                k = eval(k_.replace("n", str(shp[axis])))
            else:
                k = k_
            if k == 0:
                continue
            x = tensor(name="x", shape=(None,) * len(shp), dtype=dtype)
            y = argtopk(x, k, axis=axis, sorted=sorted, idx_dtype=idx_dtype)
            fn = pytensor.function([x], y, mode=self.mode)
            assert any(
                isinstance(n.op, self.op_class) for n in fn.maker.fgraph.apply_nodes
            )
            size = reduce(int.__mul__, shp)
            xval = gen_unique_vector(size, dtype).reshape(shp)
            yval = fn(xval)
            idx = slice(-k, None) if k > 0 else slice(-k)
            # build a full N-D index selecting the k slots along `axis`
            l = axis % ndim
            r = ndim - l
            idx = (slice(None),) * l + (idx,) + (slice(None),) * (r - 1)
            goal = np.argsort(xval, axis=axis)[idx].astype(idx_dtype)
            assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))

    @pytest.mark.parametrize("shp", ((257,), (17, 15), (5, 3, 5, 3), (2, 3, 5, 7, 11)))
    @pytest.mark.parametrize("k_", (1, -1, "(1+n)//2", "n-1", "-n", "1-n"))
    @pytest.mark.parametrize("sorted", [False])
    def test_grad(self, shp, k_, sorted):
        """Numerically verify the topk gradient over every axis."""
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # k_ is an expression in "n" (the extent along `axis`)
                k = eval(k_.replace("n", str(shp[axis])))
            else:
                k = k_
            if k == 0:
                continue
            # make input away from undefined gradient (where some inputs are equal)
            xval = gen_unique_vector(
                reduce(int.__mul__, shp), dtype=pytensor.config.floatX
            ).reshape(shp)
            utt.verify_grad(
                lambda x: topk(x, k, axis=axis, sorted=sorted), [xval], eps=1e-2
            )
class TestTopKInferShape(utt.InferShapeTester):
    @pytest.mark.parametrize(
        "shp", ((2, 3), (15, 17), (11, 7, 5), (2, 3, 5, 7, 11), (2, 4, 3, 1))
    )
    @pytest.mark.parametrize("k_", (1, "(1+n)//2", "n-1", "n"))
    def test_combined_infer_shape(self, shp, k_):
        """Check TopKOp.infer_shape against the computed outputs on every axis."""
        ndim = len(shp)
        for axis in range(-ndim, ndim):
            if isinstance(k_, str):
                # k_ is an expression in "n" (the extent along `axis`);
                # eval is safe: inputs come only from the parametrize list.
                k = eval(k_.replace("n", str(shp[axis])))
            else:
                k = k_
            if k == 0:
                continue
            x = tensor(name="x", shape=(None,) * len(shp), dtype=pytensor.config.floatX)
            yv, yi = topk_and_argtopk(x, k, axis=axis, sorted=False, idx_dtype="int32")
            size = reduce(int.__mul__, shp)
            xval = gen_unique_vector(size, pytensor.config.floatX).reshape(shp)
            # InferShapeTester helper: compiles, runs, and cross-checks
            # the op's infer_shape against the actual output shapes.
            self._compile_and_check([x], [yv, yi], [xval], TopKOp)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论