提交 3dd1f80f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove useless order kwarg from Sort and add numpy 2.0 stable kwarg

上级 35f4706c
...@@ -6,7 +6,9 @@ from pytensor.tensor.sort import SortOp ...@@ -6,7 +6,9 @@ from pytensor.tensor.sort import SortOp
@jax_funcify.register(SortOp) @jax_funcify.register(SortOp)
def jax_funcify_Sort(op, **kwargs): def jax_funcify_Sort(op, **kwargs):
stable = op.kind == "stable"
def sort(arr, axis): def sort(arr, axis):
return jnp.sort(arr, axis=axis) return jnp.sort(arr, axis=axis, stable=stable)
return sort return sort
import typing
import numpy as np import numpy as np
from pytensor.gradient import grad_undefined from pytensor.gradient import grad_undefined
...@@ -9,20 +11,34 @@ from pytensor.tensor.math import eq, ge ...@@ -9,20 +11,34 @@ from pytensor.tensor.math import eq, ge
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
KIND = typing.Literal["quicksort", "mergesort", "heapsort", "stable"]
KIND_VALUES = typing.get_args(KIND)
def _parse_sort_args(kind: KIND | None, order, stable: bool | None) -> KIND:
if order is not None:
raise ValueError("The order argument is not applicable to PyTensor graphs")
if stable is not None and kind is not None:
raise ValueError("kind and stable cannot be set at the same time")
if stable:
kind = "stable"
elif kind is None:
kind = "quicksort"
if kind not in KIND_VALUES:
raise ValueError(f"kind must be one of {KIND_VALUES}, got {kind}")
return kind
class SortOp(Op): class SortOp(Op):
""" """
This class is a wrapper for numpy sort function. This class is a wrapper for numpy sort function.
""" """
__props__ = ("kind", "order") __props__ = ("kind",)
def __init__(self, kind, order=None): def __init__(self, kind: KIND):
self.kind = kind self.kind = kind
self.order = order
def __str__(self):
return self.__class__.__name__ + f"{{{self.kind}, {self.order}}}"
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = as_tensor_variable(input) input = as_tensor_variable(input)
...@@ -33,7 +49,7 @@ class SortOp(Op): ...@@ -33,7 +49,7 @@ class SortOp(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
a, axis = inputs a, axis = inputs
z = output_storage[0] z = output_storage[0]
z[0] = np.sort(a, int(axis), self.kind, self.order) z[0] = np.sort(a, int(axis), self.kind)
def infer_shape(self, fgraph, node, inputs_shapes): def infer_shape(self, fgraph, node, inputs_shapes):
assert node.inputs[0].ndim == node.outputs[0].ndim assert node.inputs[0].ndim == node.outputs[0].ndim
...@@ -75,9 +91,9 @@ class SortOp(Op): ...@@ -75,9 +91,9 @@ class SortOp(Op):
# The goal is to get gradient wrt input from gradient # The goal is to get gradient wrt input from gradient
# wrt sort(input, axis) # wrt sort(input, axis)
idx = argsort(a, axis, kind=self.kind, order=self.order) idx = argsort(a, axis, kind=self.kind)
# rev_idx is the reverse of previous argsort operation # rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order) rev_idx = argsort(idx, axis, kind=self.kind)
indices = [] indices = []
axis_data = switch(ge(axis.data, 0), axis.data, a.ndim + axis.data) axis_data = switch(ge(axis.data, 0), axis.data, a.ndim + axis.data)
for i in range(a.ndim): for i in range(a.ndim):
...@@ -101,7 +117,9 @@ class SortOp(Op): ...@@ -101,7 +117,9 @@ class SortOp(Op):
""" """
def sort(a, axis=-1, kind="quicksort", order=None): def sort(
a, axis=-1, kind: KIND | None = None, order=None, *, stable: bool | None = None
):
""" """
Parameters Parameters
...@@ -111,12 +129,12 @@ def sort(a, axis=-1, kind="quicksort", order=None): ...@@ -111,12 +129,12 @@ def sort(a, axis=-1, kind="quicksort", order=None):
axis: TensorVariable axis: TensorVariable
Axis along which to sort. If None, the array is flattened before Axis along which to sort. If None, the array is flattened before
sorting. sorting.
kind: {'quicksort', 'mergesort', 'heapsort'}, optional kind: {'quicksort', 'mergesort', 'heapsort' 'stable'}, optional
Sorting algorithm. Default is 'quicksort'. Sorting algorithm. Default is 'quicksort' unless stable is defined.
order: list, optional order: list, optional
When `a` is a structured array, this argument specifies which For compatibility with numpy sort signature. Cannot be specified.
fields to compare first, second, and so on. This list does not stable: bool, optional
need to include all of the fields. Same as specifying kind = 'stable'. Cannot be specified at the same time as kind
Returns Returns
------- -------
...@@ -124,10 +142,12 @@ def sort(a, axis=-1, kind="quicksort", order=None): ...@@ -124,10 +142,12 @@ def sort(a, axis=-1, kind="quicksort", order=None):
A sorted copy of an array. A sorted copy of an array.
""" """
kind = _parse_sort_args(kind, order, stable)
if axis is None: if axis is None:
a = a.flatten() a = a.flatten()
axis = 0 axis = 0
return SortOp(kind, order)(a, axis) return SortOp(kind)(a, axis)
class ArgSortOp(Op): class ArgSortOp(Op):
...@@ -136,14 +156,10 @@ class ArgSortOp(Op): ...@@ -136,14 +156,10 @@ class ArgSortOp(Op):
""" """
__props__ = ("kind", "order") __props__ = ("kind",)
def __init__(self, kind, order=None): def __init__(self, kind: KIND):
self.kind = kind self.kind = kind
self.order = order
def __str__(self):
return self.__class__.__name__ + f"{{{self.kind}, {self.order}}}"
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = as_tensor_variable(input) input = as_tensor_variable(input)
...@@ -158,7 +174,7 @@ class ArgSortOp(Op): ...@@ -158,7 +174,7 @@ class ArgSortOp(Op):
a, axis = inputs a, axis = inputs
z = output_storage[0] z = output_storage[0]
z[0] = _asarray( z[0] = _asarray(
np.argsort(a, int(axis), self.kind, self.order), np.argsort(a, int(axis), self.kind),
dtype=node.outputs[0].dtype, dtype=node.outputs[0].dtype,
) )
...@@ -192,7 +208,9 @@ class ArgSortOp(Op): ...@@ -192,7 +208,9 @@ class ArgSortOp(Op):
""" """
def argsort(a, axis=-1, kind="quicksort", order=None): def argsort(
a, axis=-1, kind: KIND | None = None, order=None, stable: bool | None = None
):
""" """
Returns the indices that would sort an array. Returns the indices that would sort an array.
...@@ -202,7 +220,8 @@ def argsort(a, axis=-1, kind="quicksort", order=None): ...@@ -202,7 +220,8 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
order. order.
""" """
kind = _parse_sort_args(kind, order, stable)
if axis is None: if axis is None:
a = a.flatten() a = a.flatten()
axis = 0 axis = 0
return ArgSortOp(kind, order)(a, axis) return ArgSortOp(kind)(a, axis)
import numpy as np import numpy as np
import pytest
import pytensor import pytensor
from pytensor.tensor.sort import ArgSortOp, SortOp, argsort, sort from pytensor.tensor.sort import ArgSortOp, SortOp, argsort, sort
...@@ -65,13 +66,12 @@ class TestSort: ...@@ -65,13 +66,12 @@ class TestSort:
utt.assert_allclose(gv, gt) utt.assert_allclose(gv, gt)
def test5(self): def test5(self):
a1 = SortOp("mergesort", []) a1 = SortOp("mergesort")
a2 = SortOp("quicksort", []) a2 = SortOp("quicksort")
# All the below should give true
assert a1 != a2 assert a1 != a2
assert a1 == SortOp("mergesort", []) assert a1 == SortOp("mergesort")
assert a2 == SortOp("quicksort", []) assert a2 == SortOp("quicksort")
def test_None(self): def test_None(self):
a = dmatrix() a = dmatrix()
...@@ -208,14 +208,11 @@ def test_argsort(): ...@@ -208,14 +208,11 @@ def test_argsort():
utt.assert_allclose(gv, gt) utt.assert_allclose(gv, gt)
# Example 5 # Example 5
a = dmatrix() a1 = ArgSortOp("mergesort")
axis = lscalar() a2 = ArgSortOp("quicksort")
a1 = ArgSortOp("mergesort", [])
a2 = ArgSortOp("quicksort", [])
# All the below should give true
assert a1 != a2 assert a1 != a2
assert a1 == ArgSortOp("mergesort", []) assert a1 == ArgSortOp("mergesort")
assert a2 == ArgSortOp("quicksort", []) assert a2 == ArgSortOp("quicksort")
# Example 6: Testing axis=None # Example 6: Testing axis=None
a = dmatrix() a = dmatrix()
...@@ -237,3 +234,22 @@ def test_argsort_grad(): ...@@ -237,3 +234,22 @@ def test_argsort_grad():
data = rng.random((2, 3, 3)).astype(pytensor.config.floatX) data = rng.random((2, 3, 3)).astype(pytensor.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data]) utt.verify_grad(lambda x: argsort(x, axis=2), [data])
@pytest.mark.parametrize("func", (sort, argsort))
def test_parse_sort_args(func):
x = matrix("x")
assert func(x).owner.op.kind == "quicksort"
assert func(x, stable=True).owner.op.kind == "stable"
with pytest.raises(ValueError, match="kind must be one of"):
func(x, kind="hanoi")
with pytest.raises(
ValueError, match="kind and stable cannot be set at the same time"
):
func(x, kind="quicksort", stable=True)
with pytest.raises(ValueError, match="order argument is not applicable"):
func(x, order=[])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论