提交 d4a0433d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

compare_numba_and_py: Check for accidental input mutation

上级 fbee4164
import contextlib import contextlib
import copy
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from unittest import mock from unittest import mock
...@@ -232,8 +233,19 @@ def compare_numba_and_py( ...@@ -232,8 +233,19 @@ def compare_numba_and_py(
): ):
raise ValueError("Inputs must be root variables") raise ValueError("Inputs must be root variables")
test_input_deepcopy = None
if not inplace:
test_input_deepcopy = [
i.copy() if isinstance(i, np.ndarray) else copy.deepcopy(i)
for i in test_inputs
]
pytensor_py_fn = function( pytensor_py_fn = function(
graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates graph_inputs,
graph_outputs,
mode=py_mode,
accept_inplace=inplace,
updates=updates,
) )
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
...@@ -250,11 +262,20 @@ def compare_numba_and_py( ...@@ -250,11 +262,20 @@ def compare_numba_and_py(
graph_inputs, graph_inputs,
graph_outputs, graph_outputs,
mode=numba_mode, mode=numba_mode,
accept_inplace=True, accept_inplace=inplace,
updates=updates, updates=updates,
) )
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy) numba_res = pytensor_numba_fn(*test_inputs_copy)
if not inplace:
# Check we did not accidentally modify the inputs inplace
for test_input, test_input_copy in zip(test_inputs, test_input_deepcopy):
try:
assert_fn(test_input, test_input_copy)
except AssertionError as e:
raise AssertionError("Inputs were modified inplace") from e
if isinstance(graph_outputs, tuple | list): if isinstance(graph_outputs, tuple | list):
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True): for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
assert_fn(numba_res_i, python_res_i) assert_fn(numba_res_i, python_res_i)
......
...@@ -117,11 +117,14 @@ add_inplace = Elemwise(scalar_add, {0: 0}) ...@@ -117,11 +117,14 @@ add_inplace = Elemwise(scalar_add, {0: 0})
) )
def test_Elemwise(inputs, input_vals, output_fn): def test_Elemwise(inputs, input_vals, output_fn):
outputs = output_fn(*inputs) outputs = output_fn(*inputs)
if not isinstance(outputs, tuple | list):
outputs = [outputs]
compare_numba_and_py( compare_numba_and_py(
inputs, inputs,
outputs, outputs,
input_vals, input_vals,
inplace=outputs[0].owner.op.destroy_map,
) )
......
...@@ -84,6 +84,7 @@ def test_CumOp(val, axis, mode): ...@@ -84,6 +84,7 @@ def test_CumOp(val, axis, mode):
) )
@pytest.mark.xfail(reason="Implementation works inplace!")
def test_FillDiagonal(): def test_FillDiagonal():
a = pt.lmatrix("a") a = pt.lmatrix("a")
test_a = np.zeros((10, 2), dtype="int64") test_a = np.zeros((10, 2), dtype="int64")
......
from functools import partial
import numpy as np import numpy as np
import pytest import pytest
import scipy as sp import scipy as sp
...@@ -16,6 +18,23 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -16,6 +18,23 @@ from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings("error") pytestmark = pytest.mark.filterwarnings("error")
def sparse_assert_fn(a, b):
a_is_sparse = sp.sparse.issparse(a)
assert a_is_sparse == sp.sparse.issparse(b)
if a_is_sparse:
assert a.format == b.format
assert a.dtype == b.dtype
assert a.shape == b.shape
np.testing.assert_allclose(a.data, b.data, strict=True)
np.testing.assert_allclose(a.indices, b.indices, strict=True)
np.testing.assert_allclose(a.indptr, b.indptr, strict=True)
else:
np.testing.assert_allclose(a, b, strict=True)
compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn)
def test_sparse_unboxing(): def test_sparse_unboxing():
@numba.njit @numba.njit
def test_unboxing(x, y): def test_unboxing(x, y):
...@@ -93,11 +112,15 @@ def test_sparse_objmode(): ...@@ -93,11 +112,15 @@ def test_sparse_objmode():
out = Dot()(x, y) out = Dot()(x, y)
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
with pytest.warns( with pytest.warns(
UserWarning, UserWarning,
match="Numba will use object mode to run SparseDot's perform method", match="Numba will use object mode to run SparseDot's perform method",
): ):
compare_numba_and_py([x, y], out, [x_val, y_val]) compare_numba_and_py_sparse(
[x, y],
out,
[x_val, y_val],
)
...@@ -259,7 +259,7 @@ def test_IncSubtensor(x, y, indices): ...@@ -259,7 +259,7 @@ def test_IncSubtensor(x, y, indices):
x_pt = x.type() x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y, inplace=True) out_pt = set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, IncSubtensor) assert isinstance(out_pt.owner.op, IncSubtensor)
compare_numba_and_py([x_pt], [out_pt], [x.data]) compare_numba_and_py([x_pt], [out_pt], [x.data], inplace=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -313,13 +313,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -313,13 +313,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices) out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data]) compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)( out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices x_pt, y_pt, *indices
) )
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data]) compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -526,7 +526,9 @@ def test_AdvancedIncSubtensor( ...@@ -526,7 +526,9 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode if set_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode) fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
...@@ -546,7 +548,9 @@ def test_AdvancedIncSubtensor( ...@@ -546,7 +548,9 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode if inc_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode) fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
x_orig = x.copy() x_orig = x.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论