提交 d4a0433d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

compare_numba_and_py: Check for accidental input mutation

上级 fbee4164
import contextlib
import copy
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from unittest import mock
......@@ -232,8 +233,19 @@ def compare_numba_and_py(
):
raise ValueError("Inputs must be root variables")
test_input_deepcopy = None
if not inplace:
test_input_deepcopy = [
i.copy() if isinstance(i, np.ndarray) else copy.deepcopy(i)
for i in test_inputs
]
pytensor_py_fn = function(
graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates
graph_inputs,
graph_outputs,
mode=py_mode,
accept_inplace=inplace,
updates=updates,
)
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
......@@ -250,11 +262,20 @@ def compare_numba_and_py(
graph_inputs,
graph_outputs,
mode=numba_mode,
accept_inplace=True,
accept_inplace=inplace,
updates=updates,
)
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
if not inplace:
# Check we did not accidentally modify the inputs inplace
for test_input, test_input_copy in zip(test_inputs, test_input_deepcopy):
try:
assert_fn(test_input, test_input_copy)
except AssertionError as e:
raise AssertionError("Inputs were modified inplace") from e
if isinstance(graph_outputs, tuple | list):
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
assert_fn(numba_res_i, python_res_i)
......
......@@ -117,11 +117,14 @@ add_inplace = Elemwise(scalar_add, {0: 0})
)
def test_Elemwise(inputs, input_vals, output_fn):
outputs = output_fn(*inputs)
if not isinstance(outputs, tuple | list):
outputs = [outputs]
compare_numba_and_py(
inputs,
outputs,
input_vals,
inplace=outputs[0].owner.op.destroy_map,
)
......
......@@ -84,6 +84,7 @@ def test_CumOp(val, axis, mode):
)
@pytest.mark.xfail(reason="Implementation works inplace!")
def test_FillDiagonal():
a = pt.lmatrix("a")
test_a = np.zeros((10, 2), dtype="int64")
......
from functools import partial
import numpy as np
import pytest
import scipy as sp
......@@ -16,6 +18,23 @@ from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings("error")
def sparse_assert_fn(a, b):
a_is_sparse = sp.sparse.issparse(a)
assert a_is_sparse == sp.sparse.issparse(b)
if a_is_sparse:
assert a.format == b.format
assert a.dtype == b.dtype
assert a.shape == b.shape
np.testing.assert_allclose(a.data, b.data, strict=True)
np.testing.assert_allclose(a.indices, b.indices, strict=True)
np.testing.assert_allclose(a.indptr, b.indptr, strict=True)
else:
np.testing.assert_allclose(a, b, strict=True)
compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn)
def test_sparse_unboxing():
@numba.njit
def test_unboxing(x, y):
......@@ -93,11 +112,15 @@ def test_sparse_objmode():
out = Dot()(x, y)
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc")
with pytest.warns(
UserWarning,
match="Numba will use object mode to run SparseDot's perform method",
):
compare_numba_and_py([x, y], out, [x_val, y_val])
compare_numba_and_py_sparse(
[x, y],
out,
[x_val, y_val],
)
......@@ -259,7 +259,7 @@ def test_IncSubtensor(x, y, indices):
x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, IncSubtensor)
compare_numba_and_py([x_pt], [out_pt], [x.data])
compare_numba_and_py([x_pt], [out_pt], [x.data], inplace=True)
@pytest.mark.parametrize(
......@@ -313,13 +313,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data], inplace=True)
@pytest.mark.parametrize(
......@@ -526,7 +526,9 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
......@@ -546,7 +548,9 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
x_orig = x.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论