提交 306aceb5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move sort Ops dispatchers to their own file

上级 c90ee4ba
......@@ -12,6 +12,7 @@ import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.shape
import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sort
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic
......
......@@ -27,7 +27,6 @@ from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType
......@@ -317,68 +316,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
def sort_f(a, axis):
axis = axis.item()
a_swapped = np.swapaxes(a, axis, -1)
a_sorted = np.sort(a_swapped)
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
......
import warnings
import numpy as np
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
def sort_f(a, axis):
axis = axis.item()
a_swapped = np.swapaxes(a, axis, -1)
a_sorted = np.sort(a_swapped)
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
return a_sorted_swapped
if op.kind != "quicksort":
warnings.warn(
(
f'Numba function sort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return sort_f
@numba_funcify.register(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
def argsort_f_kind(kind):
@numba_njit
def argort_vec(X, axis):
axis = axis.item()
Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y, dtype="int64")
indices = list(np.ndindex(Y.shape[1:]))
for idx in indices:
result[(slice(None), *idx)] = np.argsort(
Y[(slice(None), *idx)], kind=kind
)
result = np.swapaxes(result, 0, axis)
return result
return argort_vec
kind = op.kind
if kind not in ["quicksort", "mergesort"]:
kind = "quicksort"
warnings.warn(
(
f'Numba function argsort doesn\'t support kind="{op.kind}"'
" switching to `quicksort`."
),
UserWarning,
)
return argsort_f_kind(kind)
......@@ -31,7 +31,6 @@ from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.sort import ArgSortOp, SortOp
if TYPE_CHECKING:
......@@ -331,70 +330,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["mergesort", UserWarning],
["heapsort", UserWarning],
["stable", UserWarning],
],
)
def test_Sort(x, axis, kind, exc):
if axis:
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = SortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
None, # Multi-dimensional array (see below)
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["heapsort", None],
["stable", UserWarning],
],
)
def test_ArgSort(x, axis, kind, exc):
if x is None:
x = np.arange(5 * 5 * 5 * 5)
np.random.shuffle(x)
x = np.reshape(x, (5, 5, 5, 5))
if axis:
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
def test_ViewOp():
v = pt.vector()
v_test_value = np.arange(4, dtype=config.floatX)
......
import contextlib
import numpy as np
import pytest
from pytensor import tensor as pt
from pytensor.tensor.sort import ArgSortOp, SortOp
from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["mergesort", UserWarning],
["heapsort", UserWarning],
["stable", UserWarning],
],
)
def test_Sort(x, axis, kind, exc):
if axis:
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = SortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"x",
[
[], # Empty list
[3, 2, 1], # Simple list
None, # Multi-dimensional array (see below)
],
)
@pytest.mark.parametrize("axis", [0, -1, None])
@pytest.mark.parametrize(
("kind", "exc"),
[
["quicksort", None],
["heapsort", None],
["stable", UserWarning],
],
)
def test_ArgSort(x, axis, kind, exc):
if x is None:
x = np.arange(5 * 5 * 5 * 5)
np.random.shuffle(x)
x = np.reshape(x, (5, 5, 5, 5))
if axis:
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
else:
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
cm = contextlib.suppress() if not exc else pytest.warns(exc)
with cm:
compare_numba_and_py([], [g], [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论