提交 5f6c0103 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test organization: group Solve and Decompositions tests

上级 993c2c64
...@@ -410,12 +410,86 @@ class TestSolves: ...@@ -410,12 +410,86 @@ class TestSolves:
# Can never destroy non-contiguous inputs # Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(
self, b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool
):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") rng = np.random.default_rng(418)
@pytest.mark.parametrize( A_val = rng.normal(size=(5, 5)).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
lu_and_piv = pt.linalg.lu_factor(A)
X = pt.linalg.lu_solve(
lu_and_piv,
b,
b_ndim=len(b_shape),
trans=trans,
)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
if should_destroy:
assert not all_equal
else:
assert all_equal
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert not (
should_destroy and b_val_c_contig.flags.f_contiguous
) == np.allclose(b_val_c_contig, b_val)
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
class TestDecompositions:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
) )
def test_cholesky(lower: bool, overwrite_a: bool): def test_cholesky(self, lower: bool, overwrite_a: bool):
cov = pt.matrix("cov") cov = pt.matrix("cov")
chol = pt.linalg.cholesky(cov, lower=lower) chol = pt.linalg.cholesky(cov, lower=lower)
...@@ -459,8 +533,7 @@ def test_cholesky(lower: bool, overwrite_a: bool): ...@@ -459,8 +533,7 @@ def test_cholesky(lower: bool, overwrite_a: bool):
# Cannot destroy non-contiguous input # Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, val) np.testing.assert_allclose(val_not_contig, val)
def test_cholesky_raises_on_nan_input(self):
def test_cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan test_value[0, 0] = np.nan
...@@ -472,9 +545,8 @@ def test_cholesky_raises_on_nan_input(): ...@@ -472,9 +545,8 @@ def test_cholesky_raises_on_nan_input():
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value) f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
@pytest.mark.parametrize("on_error", ["nan", "raise"]) def test_cholesky_raise_on(self, on_error):
def test_cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
...@@ -483,59 +555,22 @@ def test_cholesky_raise_on(on_error): ...@@ -483,59 +555,22 @@ def test_cholesky_raise_on(on_error):
if on_error == "raise": if on_error == "raise":
with pytest.raises( with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite" np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite",
): ):
f(test_value) f(test_value)
else: else:
assert np.all(np.isnan(f(test_value))) assert np.all(np.isnan(f(test_value)))
@pytest.mark.parametrize(
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
rng = np.random.default_rng(123)
A = rng.normal(size=(5, 5)).astype(floatX)
perm_pt = pt.vector("p", dtype="int32")
piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")
_, piv = scipy.linalg.lu_factor(A)
if inverse:
p = np.arange(len(piv))
for i in range(len(piv)):
p[i], p[piv[i]] = p[piv[i]], p[i]
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
@pytest.mark.parametrize(
"permute_l, p_indices", "permute_l, p_indices",
[(True, False), (False, True), (False, False)], [(True, False), (False, True), (False, False)],
ids=["PL", "p_indices", "P"], ids=["PL", "p_indices", "P"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
) )
def test_lu(permute_l, p_indices, overwrite_a): def test_lu(self, permute_l, p_indices, overwrite_a):
shape = (5, 5) shape = (5, 5)
rng = np.random.default_rng() rng = np.random.default_rng()
A = pt.tensor( A = pt.tensor(
...@@ -595,11 +630,10 @@ def test_lu(permute_l, p_indices, overwrite_a): ...@@ -595,11 +630,10 @@ def test_lu(permute_l, p_indices, overwrite_a):
# Cannot destroy non-contiguous input # Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
@pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
) )
def test_lu_factor(overwrite_a): def test_lu_factor(self, overwrite_a):
shape = (5, 5) shape = (5, 5)
rng = np.random.default_rng() rng = np.random.default_rng()
...@@ -650,88 +684,15 @@ def test_lu_factor(overwrite_a): ...@@ -650,88 +684,15 @@ def test_lu_factor(overwrite_a):
# Cannot destroy non-contiguous input # Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = rng.normal(size=(5, 5)).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
lu_and_piv = pt.linalg.lu_factor(A)
X = pt.linalg.lu_solve(
lu_and_piv,
b,
b_ndim=len(b_shape),
trans=trans,
)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
all_equal = (b_val == b_val_f_contig).all()
should_destroy = overwrite_b and trans
if should_destroy:
assert not all_equal
else:
assert all_equal
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose(
b_val_c_contig, b_val
)
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize(
"mode, pivoting", "mode, pivoting",
[("economic", False), ("full", True), ("r", False), ("raw", True)], [("economic", False), ("full", True), ("r", False), ("raw", True)],
ids=["economic", "full_pivot", "r", "raw_pivot"], ids=["economic", "full_pivot", "r", "raw_pivot"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
) )
def test_qr(mode, pivoting, overwrite_a): def test_qr(self, mode, pivoting, overwrite_a):
shape = (5, 5) shape = (5, 5)
rng = np.random.default_rng() rng = np.random.default_rng()
A = pt.tensor( A = pt.tensor(
...@@ -788,3 +749,40 @@ def test_qr(mode, pivoting, overwrite_a): ...@@ -788,3 +749,40 @@ def test_qr(mode, pivoting, overwrite_a):
# Cannot destroy non-contiguous input # Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, A_val) np.testing.assert_allclose(val_not_contig, A_val)
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
rng = np.random.default_rng(123)
A = rng.normal(size=(5, 5)).astype(floatX)
perm_pt = pt.vector("p", dtype="int32")
piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")
_, piv = scipy.linalg.lu_factor(A)
if inverse:
p = np.arange(len(piv))
for i in range(len(piv)):
p[i], p[piv[i]] = p[piv[i]], p[i]
np.testing.assert_allclose(f(piv), p)
else:
p, *_ = scipy.linalg.lu(A, p_indices=True)
np.testing.assert_allclose(f(piv), p)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论