提交 a7738482 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Return PyTensor JAX function in `compare_jax_and_py` helper

上级 e68bc6f0
...@@ -89,7 +89,7 @@ def compare_jax_and_py( ...@@ -89,7 +89,7 @@ def compare_jax_and_py(
else: else:
assert_fn(jax_res, py_res) assert_fn(jax_res, py_res)
return jax_res return pytensor_jax_fn, jax_res
def test_jax_FunctionGraph_once(): def test_jax_FunctionGraph_once():
......
...@@ -33,7 +33,7 @@ def test_jax_basic(): ...@@ -33,7 +33,7 @@ def test_jax_basic():
np.tile(np.arange(10), (10, 1)).astype(config.floatX), np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
] ]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals) _, [jax_res] = compare_jax_and_py(out_fg, test_input_vals)
# Confirm that the `Subtensor` slice operations are correct # Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3) assert jax_res.shape == (5, 3)
......
...@@ -15,7 +15,7 @@ def test_jax_Alloc(): ...@@ -15,7 +15,7 @@ def test_jax_Alloc():
x = at.alloc(0.0, 2, 3) x = at.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) _, [jax_res] = compare_jax_and_py(x_fg, [])
assert jax_res.shape == (2, 3) assert jax_res.shape == (2, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论