提交 35f4706c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move JAX sort dispatch to its own module

上级 2307d877
...@@ -14,5 +14,6 @@ import pytensor.link.jax.dispatch.elemwise ...@@ -14,5 +14,6 @@ import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.blockwise import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.sort
# isort: on # isort: on
from jax import numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.sort import SortOp
@jax_funcify.register(SortOp)
def jax_funcify_Sort(op, **kwargs):
def sort(arr, axis):
return jnp.sort(arr, axis=axis)
return sort
...@@ -22,7 +22,6 @@ from pytensor.tensor.basic import ( ...@@ -22,7 +22,6 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
from pytensor.tensor.sort import SortOp
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
...@@ -206,11 +205,3 @@ def jax_funcify_Tri(op, node, **kwargs): ...@@ -206,11 +205,3 @@ def jax_funcify_Tri(op, node, **kwargs):
return jnp.tri(*args, dtype=op.dtype) return jnp.tri(*args, dtype=op.dtype)
return tri return tri
@jax_funcify.register(SortOp)
def jax_funcify_Sort(op, **kwargs):
def sort(arr, axis):
return jnp.sort(arr, axis=axis)
return sort
import numpy as np
import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import sort
from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = sort(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
...@@ -218,15 +218,6 @@ def test_tri(): ...@@ -218,15 +218,6 @@ def test_tri():
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = pytensor.tensor.sort(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
def test_tri_nonconcrete(): def test_tri_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values.""" """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论