提交 68a8e224 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Factor out Python-only evaluation in tests.link.test_numba.compare_numba_and_py

上级 ba436b05
import contextlib
import inspect
from unittest import mock
import numba
......@@ -60,6 +61,66 @@ def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype
def eval_python_only(fn_inputs, fgraph, inputs):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""
def py_tuple_setitem(t, i, v):
ll = list(t)
ll[i] = v
return tuple(ll)
def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x
def njit_noop(*args, **kwargs):
if len(args) == 1:
return args[0]
else:
return lambda x: x
def vectorize_noop(*args, **kwargs):
def wrap(fn):
# `numba.vectorize` allows an `out` positional argument. We need
# to account for that
sig = inspect.signature(fn)
nparams = len(sig.parameters)
def inner_vec(*args):
if len(args) > nparams:
out = args[-1]
out[:] = fn(*args[:nparams])
else:
return fn(*args)
return inner_vec
return wrap
with mock.patch("aesara.link.numba.dispatch.numba.njit", njit_noop), mock.patch(
"aesara.link.numba.dispatch.numba.vectorize",
vectorize_noop,
), mock.patch(
"aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem
), mock.patch(
"aesara.link.numba.dispatch.direct_cast", lambda x, dtype: x
), mock.patch(
"aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
), mock.patch(
"aesara.link.numba.dispatch.to_scalar", py_to_scalar
):
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
mode=numba_mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
def compare_numba_and_py(
fgraph,
inputs,
......@@ -104,38 +165,8 @@ def compare_numba_and_py(
)
numba_res = aesara_numba_fn(*inputs)
# We evaluate the Numba implementation in pure Python for coverage
# purposes.
def py_tuple_setitem(t, i, v):
l = list(t)
l[i] = v
return tuple(l)
def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x
with mock.patch("aesara.link.numba.dispatch.numba.njit", lambda x: x), mock.patch(
"aesara.link.numba.dispatch.numba.vectorize", lambda x: x
), mock.patch(
"aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem
), mock.patch(
"aesara.link.numba.dispatch.direct_cast", lambda x, dtype: x
), mock.patch(
"aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
), mock.patch(
"aesara.link.numba.dispatch.to_scalar", py_to_scalar
):
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
mode=numba_mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
# Get some coverage
eval_python_only(fn_inputs, fgraph, inputs)
if len(fgraph.outputs) > 1:
for j, p in zip(numba_res, py_res):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论