提交 7b609047 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use `"JAX"` mode by default when testing jax dispatching

上级 a7738482
...@@ -5,15 +5,13 @@ import numpy as np ...@@ -5,15 +5,13 @@ import numpy as np
import pytest import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector from pytensor.tensor.type import dscalar, scalar, vector
...@@ -27,12 +25,9 @@ def set_pytensor_flags(): ...@@ -27,12 +25,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
jax_mode = Mode( # We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"]) jax_mode = get_mode("JAX")
) py_mode = get_mode("FAST_COMPILE")
py_mode = Mode(
"py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
)
def compare_jax_and_py( def compare_jax_and_py(
......
...@@ -71,13 +71,15 @@ def test_jax_Subtensor_dynamic(): ...@@ -71,13 +71,15 @@ def test_jax_Subtensor_dynamic():
def test_jax_Subtensor_boolean_mask(): def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks.""" """JAX does not support resizing arrays with boolean masks."""
x_at = at.arange(-5, 5) x_at = at.vector("x", dtype="float64")
out_at = x_at[x_at < 0] out_at = x_at[x_at < 0]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_at], [out_at])
x_at_test = np.arange(-5, 5)
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"): with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, [x_at_test])
compare_jax_and_py(out_fg, [])
def test_jax_Subtensor_boolean_mask_reexpressible(): def test_jax_Subtensor_boolean_mask_reexpressible():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论