提交 fc193d77 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Actually check types and dtypes match in numba testing helper

NOTE: CI failing at this point
上级 4829455b
...@@ -260,9 +260,12 @@ def compare_numba_and_py( ...@@ -260,9 +260,12 @@ def compare_numba_and_py(
if assert_fn is None: if assert_fn is None:
def assert_fn(x, y): def assert_fn(x, y):
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
x, y # Make sure we don't have one input be a np.ndarray while the other is not
) if isinstance(x, np.ndarray):
assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
else:
assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"
if any( if any(
inp.owner is not None inp.owner is not None
...@@ -295,8 +298,8 @@ def compare_numba_and_py( ...@@ -295,8 +298,8 @@ def compare_numba_and_py(
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy) numba_res = pytensor_numba_fn(*test_inputs_copy)
if isinstance(graph_outputs, tuple | list): if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True): for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
assert_fn(j, p) assert_fn(numba_res_i, python_res_i)
else: else:
assert_fn(numba_res, py_res) assert_fn(numba_res, py_res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论