Commit 0fd8315f authored by Ricardo Vieira, committed by Ricardo Vieira

Fix contiguity bugs in Numba lapack routines

Also removes redundant tests
Parent a149f6c9
......@@ -26,6 +26,12 @@ from pytensor.tensor.slinalg import (
)
@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
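For illustration, a NumPy-only sketch of the behaviour this helper works around (the _sketch names are hypothetical stand-ins, not the real Numba intrinsics): because Numba's _copy_to_fortran_order hands 1-D inputs back without copying, a destructive LAPACK call could otherwise clobber the caller's vector even when no overwrite was requested.

import numpy as np

def _copy_to_fortran_order_sketch(x):
    # Stand-in for Numba's _copy_to_fortran_order: per the comment above,
    # it does not copy 1-D inputs.
    return x if x.ndim == 1 else x.copy(order="F")

def _copy_to_fortran_order_even_if_1d_sketch(x):
    # The fix: force a real copy for vectors so the caller's buffer is never
    # handed to a destructive LAPACK routine by accident.
    return x.copy() if x.ndim == 1 else _copy_to_fortran_order_sketch(x)

v = np.arange(3.0)
assert np.shares_memory(_copy_to_fortran_order_sketch(v), v)                 # aliases the input
assert not np.shares_memory(_copy_to_fortran_order_even_if_1d_sketch(v), v)  # fresh, safe-to-destroy buffer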
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
"""
......@@ -132,18 +138,13 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
# This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A)
if overwrite_b:
if B_is_1d:
B_copy = np.expand_dims(B, -1)
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
if B_is_1d:
B_copy = np.copy(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
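In plain NumPy terms, the branch above reduces to the following gate (a sketch; prepare_b is a hypothetical name): reuse B's buffer only when the caller allowed overwriting and B is already Fortran-contiguous, otherwise copy, and only then add the trailing axis TRTRS expects for a 1-D right-hand side.

import numpy as np

def prepare_b(B, overwrite_b):
    if overwrite_b and B.flags.f_contiguous:
        B_copy = B                            # LAPACK may destroy it in place
    else:
        B_copy = B.copy(order="F")            # fresh Fortran-ordered buffer
    if B.ndim == 1:
        B_copy = np.expand_dims(B_copy, -1)   # LAPACK wants a 2-D right-hand side
    return B_copy

B = np.arange(6.0).reshape(3, 2)                      # C-contiguous
assert not np.shares_memory(prepare_b(B, True), B)    # must copy: not Fortran-ordered
Bf = np.asfortranarray(B)
assert np.shares_memory(prepare_b(Bf, True), Bf)      # reused: destroyable in place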
......@@ -247,10 +248,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
numba_potrf(
UPLO,
......@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
"""
lower = op.lower
overwrite_a = False
overwrite_a = op.overwrite_a
check_finite = op.check_finite
on_error = op.on_error
......@@ -497,10 +498,10 @@ def getrf_impl(
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
......@@ -545,10 +546,10 @@ def getrs_impl(
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
......@@ -576,7 +577,7 @@ def getrs_impl(
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
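The same 1-D round trip recurs in every wrapper touched here: GETRS, TRTRS, POTRS and SYSV all want a 2-D right-hand side, so a vector is viewed as an (n, 1) column on the way in and squeezed back on the way out. A minimal NumPy sketch:

import numpy as np

b = np.arange(3.0)                  # 1-D right-hand side
B_copy = np.expand_dims(b, -1)      # (3, 1) column for LAPACK; still a view onto b's memory
# ... the LAPACK routine writes the solution into B_copy here ...
x = B_copy[..., 0]                  # back to shape (3,) for the caller
assert x.shape == b.shape and np.shares_memory(x, b)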
......@@ -681,19 +682,20 @@ def sysv_impl(
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
......@@ -903,17 +905,17 @@ def posv_impl(
_N = np.int32(A.shape[-1])
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
......@@ -1102,12 +1104,15 @@ def numba_funcify_Solve(op, node, **kwargs):
return solve
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
A, lower = A_and_lower
return linalg.cho_solve((A, lower), B)
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)
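For reference, a self-contained SciPy example of what this object-mode fallback delegates to (values are arbitrary). The error handling now lives inside the implementations themselves (SciPy validates the LAPACK return code on this path, and the overload below calls _solve_check), which is why the manual LinAlgError branches further down are removed.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
A = A @ A.T + 4 * np.eye(4)               # symmetric positive definite
C = linalg.cholesky(A, lower=True)        # lower-triangular Cholesky factor
b = rng.normal(size=4)

x = linalg.cho_solve((C, True), b, overwrite_b=False, check_finite=True)
assert np.allclose(A @ x, b)              # solves A @ x = b using the factor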
@overload(_cho_solve)
......@@ -1123,13 +1128,16 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
C_copy = _copy_to_fortran_order(C)
C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
......@@ -1144,16 +1152,18 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO,
N,
NRHS,
C_copy.view(w_type).ctypes,
C_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
_solve_check(_N, int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return B_copy[..., 0]
return B_copy
return impl
......@@ -1182,16 +1192,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to cho_solve"
)
res, info = _cho_solve(
return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
)
if info < 0:
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
elif info > 0:
raise np.linalg.LinAlgError(
"Matrix is not positive definite in input to cho_solve"
)
return res
return cho_solve
......@@ -7,6 +7,7 @@ from unittest import mock
import numpy as np
import pytest
from pytensor.compile import SymbolicInput
from tests.tensor.test_math_scipy import scipy
......@@ -120,6 +121,7 @@ opts = RewriteDatabaseQuery(
numba_mode = Mode(
NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")
)
numba_inplace_mode = numba_mode.including("inplace")
py_mode = Mode("py", opts)
rng = np.random.default_rng(42849)
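A short sketch of how the new mode is meant to be used in the in-place tests (which rewrites actually fire is the thing under test, so treat this as illustrative): compiling with it allows PyTensor's in-place rewrites to run, so a linalg op may acquire a destroy_map and overwrite a mutable input.

import pytensor
import pytensor.tensor as pt
from pytensor import In

A = pt.matrix("A")
b = pt.vector("b")
x = pt.linalg.solve(A, b)

# numba_inplace_mode is numba_mode.including("inplace"), defined above.
f = pytensor.function([A, In(b, mutable=True)], x, mode=numba_inplace_mode)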
......@@ -261,7 +263,11 @@ def compare_numba_and_py(
x, y
)
if any(inp.owner is not None for inp in graph_inputs):
if any(
inp.owner is not None
for inp in graph_inputs
if not isinstance(inp, SymbolicInput)
):
raise ValueError("Inputs must be root variables")
pytensor_py_fn = function(
......
import re
from functools import partial
from typing import Literal
import numpy as np
import pytest
from numpy.testing import assert_allclose
import scipy
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.slinalg import SolveTriangular
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
from pytensor import In, config
from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangular
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
numba = pytest.importorskip("numba")
......@@ -21,250 +19,6 @@ floatX = config.floatX
rng = np.random.default_rng(42849)
def transpose_func(x, trans):
if trans == 0:
return x
if trans == 1:
return x.T
if trans == 2:
return x.conj().T
@pytest.mark.parametrize(
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"])
@pytest.mark.parametrize(
"unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex):
if is_complex:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why?
pytest.skip("Complex inputs currently not supported to solve_triangular")
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=dtype)
def A_func(x):
x = x @ x.conj().T
x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
if unit_diag:
x_tri = pt.fill_diagonal(x_tri, 1.0)
return x_tri
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)
X = solve_op(A_func(A), b)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5))
b_val = np.random.normal(size=b_shape)
if is_complex:
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j
X_np = f(A_val.copy(), b_val.copy())
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
np.testing.assert_allclose(
A_val_transformed @ X_np,
b_val,
atol=1e-8 if floatX.endswith("64") else 1e-4,
rtol=1e-8 if floatX.endswith("64") else 1e-4,
)
compiled_fgraph = f.maker.fgraph
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_val, b_val],
)
@pytest.mark.parametrize(
"lower, unit_diag, trans",
[(True, True, True), (False, False, False)],
ids=["lower_unit_trans", "defaults"],
)
def test_solve_triangular_grad(lower, unit_diag, trans):
A_val = np.random.normal(size=(5, 5)).astype(floatX)
b_val = np.random.normal(size=(5, 5)).astype(floatX)
# utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When
# a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raised, but the result will be
# wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices
# to the space of triangular matrices, and test the gradient of that entire graph.
def A_func_pt(x):
x = x @ x.conj().T
x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX)
if unit_diag:
n = A_val.shape[0]
x_tri = x_tri[np.diag_indices(n)].set(1.0)
return transpose_func(x_tri.astype(floatX), trans)
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)
utt.verify_grad(
lambda A, b: solve_op(A_func_pt(A), b),
[A_val.copy(), b_val.copy()],
mode="NUMBA",
)
@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
def test_solve_triangular_overwrite_b_correct(overwrite_b):
# Regression test for issue #1233
rng = np.random.default_rng(utt.fetch_seed())
a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
a_test_py = np.tril(a_test_py)
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
# copy(order="F") creates an f-contiguous copy of an f-contiguous array (a plain .copy() would be c-contiguous)
a_test_nb = a_test_py.copy(order="F")
b_test_nb = b_test_py.copy(order="F")
op = SolveTriangular(
unit_diagonal=False,
lower=False,
check_finite=True,
b_ndim=2,
overwrite_b=overwrite_b,
)
a_pt = pt.matrix("a", shape=(3, 3))
b_pt = pt.matrix("b", shape=(3, 2))
out = op(a_pt, b_pt)
py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")
x_py = py_fn(a_test_py, b_test_py)
x_nb = numba_fn(a_test_nb, b_test_nb)
np.testing.assert_allclose(
py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
)
np.testing.assert_allclose(b_test_py, b_test_nb)
if overwrite_b:
np.testing.assert_allclose(b_test_py, x_py)
np.testing.assert_allclose(b_test_nb, x_nb)
@pytest.mark.parametrize("value", [np.nan, np.inf])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve_triangular_raises_on_nan_inf(value):
A = pt.matrix("A")
b = pt.matrix("b")
X = pt.linalg.solve_triangular(A, b, check_finite=True)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5)).astype(floatX)
A_sym = A_val @ A_val.conj().T
A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX)
with pytest.raises(
np.linalg.LinAlgError,
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
def test_numba_Cholesky(lower, trans):
cov = pt.matrix("cov")
if trans:
cov_ = cov.T
else:
cov_ = cov
chol = pt.linalg.cholesky(cov_, lower=lower)
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
compare_numba_and_py([cov], [chol], [val])
def test_numba_Cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x)
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky_grad(lower):
rng = np.random.default_rng(utt.fetch_seed())
L = rng.normal(size=(5, 5)).astype(floatX)
X = L @ L.T
chol_op = partial(pt.linalg.cholesky, lower=lower)
utt.verify_grad(chol_op, [X], mode="NUMBA")
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
def test_lamch():
from scipy.linalg import get_lapack_funcs
......@@ -328,171 +82,396 @@ def test_xgecon(ord_numba, ord_scipy):
np.testing.assert_allclose(rcond, rcond2)
@pytest.mark.parametrize("overwrite_a", [True, False])
def test_getrf(overwrite_a):
from scipy.linalg import lu_factor
from pytensor.link.numba.dispatch.slinalg import _getrf
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
@numba.njit()
def getrf(x, overwrite_a):
return _getrf(x, overwrite_a=overwrite_a)
x = np.random.normal(size=(5, 5)).astype(floatX)
x = np.asfortranarray(
x
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
lu, ipiv = lu_factor(x, overwrite_a=False)
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
assert info == 0
assert_allclose(LU, lu)
if overwrite_a:
assert_allclose(x, LU)
# TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
# this, though.
assert_allclose(IPIV - 1, ipiv)
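For context, the SciPy pipeline this removed test compared against (scipy.linalg.lu_factor exposes 0-based pivot indices at the Python level, while raw LAPACK ?getrf reports them 1-based, hence the IPIV - 1 above):

import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
a = np.asfortranarray(rng.normal(size=(5, 5)))
b = rng.normal(size=5)

lu, piv = lu_factor(a, overwrite_a=False)   # piv holds 0-based row-interchange indices
x = lu_solve((lu, piv), b)                  # reuses the factorization to solve a @ x = b
assert np.allclose(a @ x, b)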
@pytest.mark.parametrize("trans", [0, 1])
@pytest.mark.parametrize("overwrite_a", [True, False])
@pytest.mark.parametrize("overwrite_b", [True, False])
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
from scipy.linalg import lu_factor
from scipy.linalg import lu_solve as sp_lu_solve
from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
@numba.njit()
def lu_solve(a, b, trans, overwrite_a, overwrite_b):
lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a)
x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b)
return x, lu, info
a = np.random.normal(size=(5, 5)).astype(floatX)
b = np.random.normal(size=b_shape).astype(floatX)
# inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
a = np.asfortranarray(a)
b = np.asfortranarray(b)
lu_and_piv = lu_factor(a, overwrite_a=False)
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
x, lu, info = lu_solve(
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b
class TestSolves:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a, overwrite_b",
[(False, False), (True, False), (False, True)],
ids=["no_overwrite", "overwrite_a", "overwrite_b"],
)
assert info == 0
if overwrite_a:
assert_allclose(a, lu)
if overwrite_b:
assert_allclose(b, x)
assert_allclose(x, x_sp)
@pytest.mark.parametrize(
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
def A_func(x):
if assume_a == "pos":
x = x @ x.T
elif assume_a == "sym":
x = (x + x.T) / 2
return x
X = pt.linalg.solve(
A_func(A),
b,
assume_a=assume_a,
b_ndim=len(b_shape),
@pytest.mark.parametrize(
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
def test_solve(
self,
b_shape: tuple[int],
assume_a: Literal["gen", "sym", "pos"],
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
):
if assume_a not in ("sym", "her", "pos") and not lower:
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
pytest.skip("Skipping redundant test already covered by lower=True")
def A_func(x):
if assume_a == "pos":
x = x @ x.T
x = np.tril(x) if lower else np.triu(x)
elif assume_a == "sym":
x = (x + x.T) / 2
n = x.shape[0]
# We have to set the unused triangle to something other than zero
# to see lapack destroying it.
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
return x
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
X = pt.linalg.solve(
A,
b,
assume_a=assume_a,
b_ndim=len(b_shape),
)
f, res = compare_numba_and_py(
[In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
)
op = f.maker.fgraph.outputs[0].owner.op
assert isinstance(op, Solve)
destroy_map = op.destroy_map
if overwrite_a and overwrite_b:
raise NotImplementedError(
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
)
elif overwrite_a:
assert destroy_map == {0: [0]}
elif overwrite_b:
assert destroy_map == {0: [1]}
else:
assert destroy_map == {}
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# Should always be destroyable
assert (A_val == A_val_f_contig).all() == (not overwrite_a)
assert (b_val == b_val_f_contig).all() == (not overwrite_b)
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# 1-D b vectors that are c_contiguous are also f_contiguous, and hence still destroyable
assert np.allclose(b_val_c_contig, b_val) == (
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
)
# Test right results if inputs are not contiguous in either format
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(A_val_not_contig, A_val)
np.testing.assert_allclose(b_val_not_contig, b_val)
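The np.repeat(...)[::2] trick used here (and in the tests below) yields an array with the original values but a doubled row stride, so it is neither C- nor F-contiguous; that is what exercises the "can never destroy non-contiguous inputs" branch. A quick check:

import numpy as np

A = np.arange(9.0).reshape(3, 3)
A_not_contig = np.repeat(A, 2, axis=0)[::2]     # same values, non-unit row stride
assert np.array_equal(A_not_contig, A)
assert not A_not_contig.flags.c_contiguous
assert not A_not_contig.flags.f_contiguous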
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"transposed", [False, True], ids=lambda x: f"transposed={x}"
)
f = pytensor.function(
[pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
op = f.maker.fgraph.outputs[0].owner.op
compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy = A_val.copy()
b_val_copy = b_val.copy()
X_np = f(A_val, b_val)
# overwrite_b is preferred when both inputs can be destroyed
assert op.destroy_map == {0: [1]}
@pytest.mark.parametrize(
"unit_diagonal", [True, False], ids=lambda x: f"unit_diagonal={x}"
)
@pytest.mark.parametrize(
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_solve_triangular(
self,
b_shape: tuple[int],
lower: bool,
transposed: bool,
unit_diagonal: bool,
is_complex: bool,
overwrite_b: bool,
):
if is_complex:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why?
pytest.skip("Complex inputs currently not supported to solve_triangular")
def A_func(x):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
x = x @ x.conj().T
x_tri = scipy.linalg.cholesky(x, lower=lower).astype(dtype)
if unit_diagonal:
x_tri[np.diag_indices(x_tri.shape[0])] = 1.0
return x_tri
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
X = pt.linalg.solve_triangular(
A,
b,
lower=lower,
trans="N" if (not transposed) else ("C" if is_complex else "T"),
unit_diagonal=unit_diagonal,
b_ndim=len(b_shape),
)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
)
op = f.maker.fgraph.outputs[0].owner.op
assert isinstance(op, SolveTriangular)
destroy_map = op.destroy_map
if overwrite_b:
assert destroy_map == {0: [1]}
else:
assert destroy_map == {}
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# solve_triangular never destroys A
np.testing.assert_allclose(A_val, A_val_f_contig)
# b Should always be destroyable
assert (b_val == b_val_f_contig).all() == (not overwrite_b)
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert np.allclose(b_val_c_contig, b_val) == (
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
)
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("value", [np.nan, np.inf])
def test_solve_triangular_raises_on_nan_inf(self, value):
A = pt.matrix("A")
b = pt.matrix("b")
X = pt.linalg.solve_triangular(A, b, check_finite=True)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5)).astype(floatX)
A_sym = A_val @ A_val.conj().T
A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX)
# Confirm inputs were destroyed by checking against the copies
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
with pytest.raises(
np.linalg.LinAlgError,
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
@pytest.mark.parametrize(
"overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
"b_func, b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_cho_solve(
self, b_func, b_shape: tuple[int, ...], lower: bool, overwrite_b: bool
):
def A_func(x):
x = x @ x.conj().T
x = scipy.linalg.cholesky(x, lower=lower)
return x
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
X = pt.linalg.cho_solve(
(A, lower),
b,
b_ndim=len(b_shape),
)
f, res = compare_numba_and_py(
[A, In(b, mutable=overwrite_b)],
X,
test_inputs=[A_val, b_val],
inplace=True,
numba_mode=numba_inplace_mode,
)
op = f.maker.fgraph.outputs[0].owner.op
assert isinstance(op, CholeskySolve)
destroy_map = op.destroy_map
if overwrite_b:
assert destroy_map == {0: [1]}
else:
assert destroy_map == {}
# Test with F_contiguous inputs
A_val_f_contig = np.copy(A_val, order="F")
b_val_f_contig = np.copy(b_val, order="F")
res_f_contig = f(A_val_f_contig, b_val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# cho_solve never destroys A
np.testing.assert_allclose(A_val, A_val_f_contig)
# b Should always be destroyable
assert (b_val == b_val_f_contig).all() == (not overwrite_b)
# Test with C_contiguous inputs
A_val_c_contig = np.copy(A_val, order="C")
b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val)
# b c_contiguous vectors are also f_contiguous and destroyable
assert np.allclose(b_val_c_contig, b_val) == (
not (overwrite_b and b_val_c_contig.flags.f_contiguous)
)
# Test with non-contiguous inputs
A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
res_not_contig = f(A_val_not_contig, b_val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
np.testing.assert_allclose(A_val_not_contig, A_val)
# Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
"overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_cholesky(lower: bool, overwrite_a: bool):
cov = pt.matrix("cov")
chol = pt.linalg.cholesky(cov, lower=lower)
# Confirm b_val is used to store the solution
np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
assert not np.allclose(b_val, b_val_copy)
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
# Test that the result is numerically correct. Need to use the unmodified copy
np.testing.assert_allclose(
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
fn, res = compare_numba_and_py(
[In(cov, mutable=overwrite_a)],
[chol],
[val],
numba_mode=numba_inplace_mode,
inplace=True,
)
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
utt.verify_grad(
lambda A, b: pt.linalg.solve(
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
),
[A_val_copy, b_val_copy],
mode="NUMBA",
)
op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(op, Cholesky)
destroy_map = op.destroy_map
if overwrite_a:
assert destroy_map == {0: [0]}
else:
assert destroy_map == {}
# Test F-contiguous input
val_f_contig = np.copy(val, order="F")
res_f_contig = fn(val_f_contig)
np.testing.assert_allclose(res_f_contig, res)
# Should always be destroyable
assert (val == val_f_contig).all() == (not overwrite_a)
# Test C-contiguous input
val_c_contig = np.copy(val, order="C")
res_c_contig = fn(val_c_contig)
np.testing.assert_allclose(res_c_contig, res)
# Cannot destroy C-contiguous input
np.testing.assert_allclose(val_c_contig, val)
# Test non-contiguous input
val_not_contig = np.repeat(val, 2, axis=0)[::2]
res_not_contig = fn(val_not_contig)
np.testing.assert_allclose(res_not_contig, res)
# Cannot destroy non-contiguous input
np.testing.assert_allclose(val_not_contig, val)
def test_cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x)
f = pytensor.function([x], g, mode="NUMBA")
@pytest.mark.parametrize(
"b_func, b_size",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
def test_cho_solve(b_func, b_size, lower):
A = pt.matrix("A", dtype=floatX)
b = b_func("b", dtype=floatX)
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
C = pt.linalg.cholesky(A, lower=lower)
X = pt.linalg.cho_solve((C, lower), b)
f = pytensor.function([A, b], X, mode="NUMBA")
A = np.random.normal(size=(5, 5)).astype(floatX)
A = A @ A.conj().T
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
b = np.random.normal(size=b_size)
b = b.astype(floatX)
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
X_np = f(A, b)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
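For completeness, the SciPy reference behaviour the block_diag test checks against (a small sketch, not part of the diff):

import numpy as np
from scipy.linalg import block_diag

A = np.ones((2, 2))
B = 2 * np.ones((1, 1))
out = block_diag(A, B)
# [[1. 1. 0.]
#  [1. 1. 0.]
#  [0. 0. 2.]]
assert out.shape == (3, 3) and np.allclose(out[2, 2], 2.0)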