提交 306aceb5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move sort Ops dispatchers to their own file

上级 c90ee4ba
...@@ -12,6 +12,7 @@ import pytensor.link.numba.dispatch.scalar ...@@ -12,6 +12,7 @@ import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.shape
import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sort
import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic import pytensor.link.numba.dispatch.tensor_basic
......
...@@ -27,7 +27,6 @@ from pytensor.sparse import SparseTensorType ...@@ -27,7 +27,6 @@ from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
...@@ -317,68 +316,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -317,68 +316,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop return deepcopyop
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
def sort_f(a, axis):
axis = axis.item()
a_swapped = np.swapaxes(a, axis, -1)
a_sorted = np.sort(a_swapped)
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
@numba.extending.intrinsic @numba.extending.intrinsic
def direct_cast(typingctx, val, typ): def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef): if isinstance(typ, numba.types.TypeRef):
......
import warnings
import numpy as np
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
def sort_f(a, axis):
axis = axis.item()
a_swapped = np.swapaxes(a, axis, -1)
a_sorted = np.sort(a_swapped)
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
...@@ -31,7 +31,6 @@ from pytensor.raise_op import assert_op ...@@ -31,7 +31,6 @@ from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas, tensor from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.sort import ArgSortOp, SortOp
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -331,70 +330,6 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -331,70 +330,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected assert res == expected
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["mergesort", UserWarning],
["heapsort", UserWarning],
["stable", UserWarning],
],
)
def test_Sort(x, axis, kind, exc):
if axis:
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = SortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
None, # Multi-dimensional array (see below)
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["heapsort", None],
["stable", UserWarning],
],
)
def test_ArgSort(x, axis, kind, exc):
if x is None:
x = np.arange(5 * 5 * 5 * 5)
np.random.shuffle(x)
x = np.reshape(x, (5, 5, 5, 5))
if axis:
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
def test_ViewOp(): def test_ViewOp():
v = pt.vector() v = pt.vector()
v_test_value = np.arange(4, dtype=config.floatX) v_test_value = np.arange(4, dtype=config.floatX)
......
import contextlib
import numpy as np
import pytest
from pytensor import tensor as pt
from pytensor.tensor.sort import ArgSortOp, SortOp
from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["mergesort", UserWarning],
["heapsort", UserWarning],
["stable", UserWarning],
],
)
def test_Sort(x, axis, kind, exc):
if axis:
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = SortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
None, # Multi-dimensional array (see below)
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["heapsort", None],
["stable", UserWarning],
],
)
def test_ArgSort(x, axis, kind, exc):
if x is None:
x = np.arange(5 * 5 * 5 * 5)
np.random.shuffle(x)
x = np.reshape(x, (5, 5, 5, 5))
if axis:
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论