提交 f254492b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Use separate DB queries for each JAX test mode

上级 e108fab5
...@@ -27,9 +27,12 @@ def set_pytensor_flags(): ...@@ -27,9 +27,12 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"]) jax_mode = Mode(
jax_mode = Mode(JAXLinker(), opts) JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", opts) )
py_mode = Mode(
"py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
)
def compare_jax_and_py( def compare_jax_and_py(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论