提交 7b609047 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use `"JAX"` mode by default when testing jax dispatching

上级 a7738482
......@@ -5,15 +5,13 @@ import numpy as np
import pytest
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector
......@@ -27,12 +25,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")
jax_mode = Mode(
JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
)
py_mode = Mode(
"py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
)
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")
def compare_jax_and_py(
......
......@@ -71,13 +71,15 @@ def test_jax_Subtensor_dynamic():
def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks."""
x_at = at.arange(-5, 5)
x_at = at.vector("x", dtype="float64")
out_at = x_at[x_at < 0]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_at], [out_at])
x_at_test = np.arange(-5, 5)
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
compare_jax_and_py(out_fg, [x_at_test])
def test_jax_Subtensor_boolean_mask_reexpressible():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论