提交 a7738482 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Return PyTensor JAX function in `compare_jax_and_py` helper

上级 e68bc6f0
......@@ -89,7 +89,7 @@ def compare_jax_and_py(
else:
assert_fn(jax_res, py_res)
return jax_res
return pytensor_jax_fn, jax_res
def test_jax_FunctionGraph_once():
......
......@@ -33,7 +33,7 @@ def test_jax_basic():
np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
_, [jax_res] = compare_jax_and_py(out_fg, test_input_vals)
# Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3)
......
......@@ -15,7 +15,7 @@ def test_jax_Alloc():
x = at.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, [])
_, [jax_res] = compare_jax_and_py(x_fg, [])
assert jax_res.shape == (2, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论