提交 26ba6733 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for Argsort

上级 3dd1f80f
from jax import numpy as jnp from jax import numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.sort import SortOp from pytensor.tensor.sort import ArgSortOp, SortOp
@jax_funcify.register(SortOp) @jax_funcify.register(SortOp)
...@@ -12,3 +12,13 @@ def jax_funcify_Sort(op, **kwargs): ...@@ -12,3 +12,13 @@ def jax_funcify_Sort(op, **kwargs):
return jnp.sort(arr, axis=axis, stable=stable) return jnp.sort(arr, axis=axis, stable=stable)
return sort return sort
@jax_funcify.register(ArgSortOp)
def jax_funcify_ArgSort(op, **kwargs):
stable = op.kind == "stable"
def argsort(arr, axis):
return jnp.argsort(arr, axis=axis, stable=stable)
return argsort
...@@ -3,14 +3,15 @@ import pytest ...@@ -3,14 +3,15 @@ import pytest
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix from pytensor.tensor import matrix
from pytensor.tensor.sort import sort from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize("axis", [None, -1]) @pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis): @pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64") x = matrix("x", shape=(2, 2), dtype="float64")
out = sort(x, axis=axis) out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr]) compare_jax_and_py(fgraph, [arr])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论