提交 0356220c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make compilation modes configurable in compare_numba_and_py

上级 69c10443
...@@ -98,7 +98,7 @@ def compare_shape_dtype(x, y): ...@@ -98,7 +98,7 @@ def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype return x.shape == y.shape and x.dtype == y.dtype
def eval_python_only(fn_inputs, fgraph, inputs): def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes.""" """Evaluate the Numba implementation in pure Python for coverage purposes."""
def py_tuple_setitem(t, i, v): def py_tuple_setitem(t, i, v):
...@@ -164,13 +164,15 @@ def eval_python_only(fn_inputs, fgraph, inputs): ...@@ -164,13 +164,15 @@ def eval_python_only(fn_inputs, fgraph, inputs):
aesara_numba_fn = function( aesara_numba_fn = function(
fn_inputs, fn_inputs,
fgraph.outputs, fgraph.outputs,
mode=numba_mode, mode=mode,
accept_inplace=True, accept_inplace=True,
) )
_ = aesara_numba_fn(*inputs) _ = aesara_numba_fn(*inputs)
def compare_numba_and_py(fgraph, inputs, assert_fn=None): def compare_numba_and_py(
fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode
):
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
In the tests below computational graphs are defined in Aesara. These graphs are then passed to In the tests below computational graphs are defined in Aesara. These graphs are then passed to
...@@ -211,7 +213,7 @@ def compare_numba_and_py(fgraph, inputs, assert_fn=None): ...@@ -211,7 +213,7 @@ def compare_numba_and_py(fgraph, inputs, assert_fn=None):
numba_res = aesara_numba_fn(*inputs) numba_res = aesara_numba_fn(*inputs)
# Get some coverage # Get some coverage
eval_python_only(fn_inputs, fgraph, inputs) eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode)
if len(fgraph.outputs) > 1: if len(fgraph.outputs) > 1:
for j, p in zip(numba_res, py_res): for j, p in zip(numba_res, py_res):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论