提交 5f560e15 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Michael Osthege

Make tests compatible with newer version of JAX

上级 5affc308
......@@ -71,9 +71,7 @@ def compare_jax_and_py(
if must_be_device_array:
if isinstance(jax_res, list):
assert all(
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
assert all(isinstance(res, jax.Array) for res in jax_res)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
......@@ -146,13 +144,13 @@ def test_shared():
pytensor_jax_fn = function([], a, mode="JAX")
jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, a.get_value())
pytensor_jax_fn = function([], a * 2, mode="JAX")
jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, a.get_value() * 2)
# Changed the shared value and make sure that the JAX-compiled
......@@ -161,7 +159,7 @@ def test_shared():
a.set_value(new_a_value)
jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, new_a_value * 2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论