提交 afe290db authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Update signature of compare_jax_and_py helper

上级 03858395
from functools import partial
from typing import Optional
import numpy as np
import pytest
......@@ -64,10 +65,10 @@ def set_aesara_flags():
def compare_jax_and_py(
fgraph,
inputs,
assert_fn=None,
must_be_device_array=True,
fgraph: FunctionGraph,
test_inputs: iter,
assert_fn: Optional[callable] = None,
must_be_device_array: bool = True,
):
"""Function to compare python graph output and jax compiled output for testing equality
......@@ -79,8 +80,8 @@ def compare_jax_and_py(
----------
fgraph: FunctionGraph
Aesara function Graph object
inputs: iter
Inputs for function graph
test_inputs: iter
Numerical inputs for testing the function graph
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
......@@ -98,7 +99,7 @@ def compare_jax_and_py(
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*inputs)
jax_res = aesara_jax_fn(*test_inputs)
if must_be_device_array:
if isinstance(jax_res, list):
......@@ -109,7 +110,7 @@ def compare_jax_and_py(
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = aesara_py_fn(*inputs)
py_res = aesara_py_fn(*test_inputs)
if len(fgraph.outputs) > 1:
for j, p in zip(jax_res, py_res):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论