提交 9ae884dd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Support updates and more input types in compare_numba_and_py

上级 f8771c13
import contextlib
import inspect
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union
from unittest import mock
import numba
......@@ -31,6 +32,11 @@ from aesara.tensor.elemwise import Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
if TYPE_CHECKING:
from aesara.graph.basic import Variable
from aesara.tensor import TensorLike
class MyType(Type):
def filter(self, data):
return data
......@@ -98,7 +104,7 @@ def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype
def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""
def py_tuple_setitem(t, i, v):
......@@ -163,7 +169,7 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
......@@ -171,7 +177,12 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
def compare_numba_and_py(
fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode
fgraph: Union[FunctionGraph, Tuple[Sequence["Variable"], Sequence["Variable"]]],
inputs: Sequence["TensorLike"],
assert_fn: Optional[Callable] = None,
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
):
"""Function to compare python graph output and Numba compiled output for testing equality
......@@ -181,13 +192,15 @@ def compare_numba_and_py(
Parameters
----------
fgraph: FunctionGraph
Aesara function Graph object
inputs: iter
Inputs for function graph
assert_fn: func, opt
fgraph
`FunctionGraph` or inputs to compare.
inputs
Numeric inputs to be passed to the compiled graphs.
assert_fn
Assert function used to check for equality between python and Numba. If not
provided uses np.testing.assert_allclose
provided uses `np.testing.assert_allclose`.
updates
Updates to be passed to `aesara.function`.
"""
if assert_fn is None:
......@@ -197,25 +210,32 @@ def compare_numba_and_py(
x, y
)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
if isinstance(fgraph, tuple):
fn_inputs, fn_outputs = fgraph
else:
fn_inputs = fgraph.inputs
fn_outputs = fgraph.outputs
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
aesara_py_fn = function(
fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
)
py_res = aesara_py_fn(*inputs)
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
fn_outputs,
mode=numba_mode,
accept_inplace=True,
updates=updates,
)
numba_res = aesara_numba_fn(*inputs)
# Get some coverage
eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode)
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fgraph.outputs) > 1:
if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res):
assert_fn(j, p)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论