提交 26ba6733 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for Argsort

上级 3dd1f80f
from jax import numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.sort import SortOp
from pytensor.tensor.sort import ArgSortOp, SortOp
@jax_funcify.register(SortOp)
......@@ -12,3 +12,13 @@ def jax_funcify_Sort(op, **kwargs):
return jnp.sort(arr, axis=axis, stable=stable)
return sort
@jax_funcify.register(ArgSortOp)
def jax_funcify_ArgSort(op, **kwargs):
stable = op.kind == "stable"
def argsort(arr, axis):
return jnp.argsort(arr, axis=axis, stable=stable)
return argsort
......@@ -3,14 +3,15 @@ import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import sort
from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = sort(x, axis=axis)
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论