提交 9ae884dd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Support updates and more input types in compare_numba_and_py

上级 f8771c13
import contextlib import contextlib
import inspect import inspect
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union
from unittest import mock from unittest import mock
import numba import numba
...@@ -31,6 +32,11 @@ from aesara.tensor.elemwise import Elemwise ...@@ -31,6 +32,11 @@ from aesara.tensor.elemwise import Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
if TYPE_CHECKING:
from aesara.graph.basic import Variable
from aesara.tensor import TensorLike
class MyType(Type): class MyType(Type):
def filter(self, data): def filter(self, data):
return data return data
...@@ -98,7 +104,7 @@ def compare_shape_dtype(x, y): ...@@ -98,7 +104,7 @@ def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype return x.shape == y.shape and x.dtype == y.dtype
def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode): def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes.""" """Evaluate the Numba implementation in pure Python for coverage purposes."""
def py_tuple_setitem(t, i, v): def py_tuple_setitem(t, i, v):
...@@ -163,7 +169,7 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode): ...@@ -163,7 +169,7 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
aesara_numba_fn = function( aesara_numba_fn = function(
fn_inputs, fn_inputs,
fgraph.outputs, fn_outputs,
mode=mode, mode=mode,
accept_inplace=True, accept_inplace=True,
) )
...@@ -171,7 +177,12 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode): ...@@ -171,7 +177,12 @@ def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
def compare_numba_and_py( def compare_numba_and_py(
fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode fgraph: Union[FunctionGraph, Tuple[Sequence["Variable"], Sequence["Variable"]]],
inputs: Sequence["TensorLike"],
assert_fn: Optional[Callable] = None,
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
): ):
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
...@@ -181,13 +192,15 @@ def compare_numba_and_py( ...@@ -181,13 +192,15 @@ def compare_numba_and_py(
Parameters Parameters
---------- ----------
fgraph: FunctionGraph fgraph
Aesara function Graph object `FunctionGraph` or inputs to compare.
inputs: iter inputs
Inputs for function graph Numeric inputs to be passed to the compiled graphs.
assert_fn: func, opt assert_fn
Assert function used to check for equality between python and Numba. If not Assert function used to check for equality between python and Numba. If not
provided uses np.testing.assert_allclose provided uses `np.testing.assert_allclose`.
updates
Updates to be passed to `aesara.function`.
""" """
if assert_fn is None: if assert_fn is None:
...@@ -197,25 +210,32 @@ def compare_numba_and_py( ...@@ -197,25 +210,32 @@ def compare_numba_and_py(
x, y x, y
) )
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] if isinstance(fgraph, tuple):
fn_inputs, fn_outputs = fgraph
else:
fn_inputs = fgraph.inputs
fn_outputs = fgraph.outputs
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
aesara_py_fn = function( aesara_py_fn = function(
fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
) )
py_res = aesara_py_fn(*inputs) py_res = aesara_py_fn(*inputs)
aesara_numba_fn = function( aesara_numba_fn = function(
fn_inputs, fn_inputs,
fgraph.outputs, fn_outputs,
mode=numba_mode, mode=numba_mode,
accept_inplace=True, accept_inplace=True,
updates=updates,
) )
numba_res = aesara_numba_fn(*inputs) numba_res = aesara_numba_fn(*inputs)
# Get some coverage # Get some coverage
eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode) eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fgraph.outputs) > 1: if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res): for j, p in zip(numba_res, py_res):
assert_fn(j, p) assert_fn(j, p)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论