提交 58840bae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix JAX test check

上级 0824dba8
......@@ -76,7 +76,7 @@ def compare_jax_and_py(
if isinstance(jax_res, list):
assert all(isinstance(res, jax.Array) for res in jax_res)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
assert isinstance(jax_res, jax.Array)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论