提交 0356220c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make compilation modes configurable in compare_numba_and_py

上级 69c10443
......@@ -98,7 +98,7 @@ def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype
def eval_python_only(fn_inputs, fgraph, inputs):
def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""
def py_tuple_setitem(t, i, v):
......@@ -164,13 +164,15 @@ def eval_python_only(fn_inputs, fgraph, inputs):
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
mode=numba_mode,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
def compare_numba_and_py(fgraph, inputs, assert_fn=None):
def compare_numba_and_py(
fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode
):
"""Function to compare python graph output and Numba compiled output for testing equality
In the tests below computational graphs are defined in Aesara. These graphs are then passed to
......@@ -211,7 +213,7 @@ def compare_numba_and_py(fgraph, inputs, assert_fn=None):
numba_res = aesara_numba_fn(*inputs)
# Get some coverage
eval_python_only(fn_inputs, fgraph, inputs)
eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode)
if len(fgraph.outputs) > 1:
for j, p in zip(numba_res, py_res):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论