提交 5f560e15 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Michael Osthege

Make tests compatible with newer version of JAX

上级 5affc308
...@@ -71,9 +71,7 @@ def compare_jax_and_py( ...@@ -71,9 +71,7 @@ def compare_jax_and_py(
if must_be_device_array: if must_be_device_array:
if isinstance(jax_res, list): if isinstance(jax_res, list):
assert all( assert all(isinstance(res, jax.Array) for res in jax_res)
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else: else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
...@@ -146,13 +144,13 @@ def test_shared(): ...@@ -146,13 +144,13 @@ def test_shared():
pytensor_jax_fn = function([], a, mode="JAX") pytensor_jax_fn = function([], a, mode="JAX")
jax_res = pytensor_jax_fn() jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, a.get_value()) np.testing.assert_allclose(jax_res, a.get_value())
pytensor_jax_fn = function([], a * 2, mode="JAX") pytensor_jax_fn = function([], a * 2, mode="JAX")
jax_res = pytensor_jax_fn() jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, a.get_value() * 2) np.testing.assert_allclose(jax_res, a.get_value() * 2)
# Changed the shared value and make sure that the JAX-compiled # Changed the shared value and make sure that the JAX-compiled
...@@ -161,7 +159,7 @@ def test_shared(): ...@@ -161,7 +159,7 @@ def test_shared():
a.set_value(new_a_value) a.set_value(new_a_value)
jax_res = pytensor_jax_fn() jax_res = pytensor_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.Array)
np.testing.assert_allclose(jax_res, new_a_value * 2) np.testing.assert_allclose(jax_res, new_a_value * 2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论