提交 e68bc6f0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow overriding modes used in `compare_jax_and_py` helper

上级 29169d43
......@@ -40,6 +40,8 @@ def compare_jax_and_py(
test_inputs: Iterable,
assert_fn: Optional[Callable] = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
py_mode=py_mode,
):
"""Function to compare python graph output and jax compiled output for testing equality
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论