提交 fc193d77 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Actually check types and dtypes match in numba testing helper

NOTE: CI failing at this point
上级 4829455b
......@@ -260,9 +260,12 @@ def compare_numba_and_py(
if assert_fn is None:
def assert_fn(x, y):
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
x, y
)
np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
# Make sure we don't have one input be a np.ndarray while the other is not
if isinstance(x, np.ndarray):
assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
else:
assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"
if any(
inp.owner is not None
......@@ -295,8 +298,8 @@ def compare_numba_and_py(
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p)
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
assert_fn(numba_res_i, python_res_i)
else:
assert_fn(numba_res, py_res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论