Unverified 提交 4b1761b7 authored 作者: Victor's avatar Victor 提交者: GitHub

Support numba compiled `sort` and `argsort` functions (#1309)

* feat: support numba compiled sort and argsort functions Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * default to supported kind and add warning Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * feat: support axis Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * use syntax compatible with python 3.10 Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * remove checks Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * use range instead of prange Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * add extra case to check Axis error is raised Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> * simplify tests Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl> --------- Signed-off-by: 's avatarVictor Garcia Reolid <victor@seita.nl>
上级 0f5da80c
...@@ -38,6 +38,7 @@ from pytensor.tensor.blas import BatchedDot ...@@ -38,6 +38,7 @@ from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Solve from pytensor.tensor.slinalg import Solve
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
...@@ -433,6 +434,68 @@ def numba_funcify_Shape_i(op, **kwargs): ...@@ -433,6 +434,68 @@ def numba_funcify_Shape_i(op, **kwargs):
return shape_i return shape_i
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
def sort_f(a, axis):
axis = axis.item()
a_swapped = np.swapaxes(a, axis, -1)
a_sorted = np.sort(a_swapped)
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y)
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
@numba.extending.intrinsic @numba.extending.intrinsic
def direct_cast(typingctx, val, typ): def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef): if isinstance(typ, numba.types.TypeRef):
......
...@@ -34,6 +34,7 @@ from pytensor.scalar.basic import ScalarOp, as_scalar ...@@ -34,6 +34,7 @@ from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas from pytensor.tensor import blas
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.sort import ArgSortOp, SortOp
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -383,6 +384,70 @@ def test_Shape(x, i): ...@@ -383,6 +384,70 @@ def test_Shape(x, i):
compare_numba_and_py([], [g], []) compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["mergesort", UserWarning],
["heapsort", UserWarning],
["stable", UserWarning],
],
)
def test_Sort(x, axis, kind, exc):
if axis:
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = SortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
None, # Multi-dimensional array (see below)
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["heapsort", None],
["stable", UserWarning],
],
)
def test_ArgSort(x, axis, kind, exc):
if x is None:
x = np.arange(5 * 5 * 5 * 5)
np.random.shuffle(x)
x = np.reshape(x, (5, 5, 5, 5))
if axis:
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, shape, ndim", "v, shape, ndim",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论