提交 afe290db authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Update signature of compare_jax_and_py helper

上级 03858395
from functools import partial from functools import partial
from typing import Optional
import numpy as np import numpy as np
import pytest import pytest
...@@ -64,10 +65,10 @@ def set_aesara_flags(): ...@@ -64,10 +65,10 @@ def set_aesara_flags():
def compare_jax_and_py( def compare_jax_and_py(
fgraph, fgraph: FunctionGraph,
inputs, test_inputs: iter,
assert_fn=None, assert_fn: Optional[callable] = None,
must_be_device_array=True, must_be_device_array: bool = True,
): ):
"""Function to compare python graph output and jax compiled output for testing equality """Function to compare python graph output and jax compiled output for testing equality
...@@ -79,8 +80,8 @@ def compare_jax_and_py( ...@@ -79,8 +80,8 @@ def compare_jax_and_py(
---------- ----------
fgraph: FunctionGraph fgraph: FunctionGraph
Aesara function Graph object Aesara function Graph object
inputs: iter test_inputs: iter
Inputs for function graph Numerical inputs for testing the function graph
assert_fn: func, opt assert_fn: func, opt
Assert function used to check for equality between python and jax. If not Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose provided uses np.testing.assert_allclose
...@@ -98,7 +99,7 @@ def compare_jax_and_py( ...@@ -98,7 +99,7 @@ def compare_jax_and_py(
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*inputs) jax_res = aesara_jax_fn(*test_inputs)
if must_be_device_array: if must_be_device_array:
if isinstance(jax_res, list): if isinstance(jax_res, list):
...@@ -109,7 +110,7 @@ def compare_jax_and_py( ...@@ -109,7 +110,7 @@ def compare_jax_and_py(
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = aesara_py_fn(*inputs) py_res = aesara_py_fn(*test_inputs)
if len(fgraph.outputs) > 1: if len(fgraph.outputs) > 1:
for j, p in zip(jax_res, py_res): for j, p in zip(jax_res, py_res):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论