Unverified 提交 d175203b authored 作者: Harshvir Sandhu's avatar Harshvir Sandhu 提交者: GitHub

Add JAX support for SortOp (#657)

上级 ad55b69f
......@@ -22,6 +22,7 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.sort import SortOp
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
......@@ -205,3 +206,11 @@ def jax_funcify_Tri(op, node, **kwargs):
return jnp.tri(*args, dtype=op.dtype)
return tri
@jax_funcify.register(SortOp)
def jax_funcify_Sort(op, **kwargs):
def sort(arr, axis):
return jnp.sort(arr, axis=axis)
return sort
......@@ -218,6 +218,15 @@ def test_tri():
compare_jax_and_py(fgraph, [])
@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = pytensor.tensor.sort(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
def test_tri_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论