提交 e98cbbcf authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: Jesse Grabowski

Numba dispatch for LU ops

上级 679b2f71
...@@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs): ...@@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=( message=(
"(\x1b\\[1m)*" # ansi escape code for bold text "(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function " "Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
"as it uses dynamic globals" "as it uses dynamic globals"
), ),
category=NumbaWarning, category=NumbaWarning,
......
from collections.abc import Callable
from typing import cast as typing_cast
import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_njit
def _pivot_to_permutation(p, dtype):
    # Replay the row swaps encoded in the pivot vector ``p`` on an identity
    # permutation; callers treat the result as the inverse permutation.
    p_inv = np.arange(len(p)).astype(dtype)
    for idx in range(len(p)):
        swap_with = p[idx]
        tmp = p_inv[idx]
        p_inv[idx] = p_inv[swap_with]
        p_inv[swap_with] = tmp
    return p_inv
@numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
    """Factorize ``a`` via getrf and unpack the packed result into (perm, L, U).

    NOTE(review): the INFO code from _getrf is discarded here, so factorization
    failures are not surfaced. Also assumes ``a`` is square (np.eye uses the
    last axis only) — confirm the Op guarantees this.
    """
    A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)
    # Unit-diagonal L comes from the strict lower triangle of the packed
    # factorization; U is the upper triangle including the diagonal.
    L = np.eye(A_copy.shape[-1], dtype=dtype)
    L += np.tril(A_copy, k=-1)
    U = np.triu(A_copy)
    # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
    IPIV = IPIV - 1
    # Convert the swap-encoded pivots to an inverse permutation, then invert
    # it (argsort of a permutation is its inverse) to get the row order.
    p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
    perm = np.argsort(p_inv)
    return perm, L, U
def _lu_1(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_2(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_3(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when p_indices is True and permute_l is
    False (assumed — the dispatcher only checks p_indices; confirm the Op forbids both flags). Returns a tuple of
    (perm, L, U), where perm is an integer array of row indices, such that L[perm] @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Resolved once at typing time; baked into the compiled closure.
    dtype = a.dtype
    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # permute_l / p_indices / check_finite were already resolved by the
        # dispatcher; only overwrite_a influences the computation here.
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        return perm, L, U
    return impl
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
    False. Returns a tuple of (PL, U), where PL is the row-permuted L matrix, PL = P @ L.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Resolved once at typing time; baked into the compiled closure.
    dtype = a.dtype
    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Fancy-indexing L with the permutation is equivalent to P @ L.
        PL = L[perm]
        return PL, U
    return impl
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when both permute_l and p_indices are
    False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Resolved once at typing time; baked into the compiled closure.
    dtype = a.dtype
    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Materialize the permutation matrix by permuting identity rows.
        P = np.eye(a.shape[-1], dtype=dtype)[perm]
        return P, L, U
    return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
return A_copy, ipiv, info
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """Numba overload of ``_getrf``: call LAPACK ?getrf through ctypes.

    Returns the packed LU factors, the raw pivot vector (1-based, Fortran
    convention — callers subtract 1), and the LAPACK INFO code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)
    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore
        # getrf factorizes in place, so only reuse A when the caller allows it
        # AND A is already Fortran-ordered; otherwise work on an F-ordered copy.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        # NOTE(review): LAPACK defines IPIV length as min(M, N); using _N
        # over-allocates for wide matrices (harmless) — confirm inputs are
        # square or tall where this matters.
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)
        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
        return A_copy, IPIV, int_ptr_to_val(INFO)
    return impl
def _lu_factor(A: np.ndarray, overwrite_a: bool = False):
"""
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
Pytensor.
"""
return linalg.lu_factor(A, overwrite_a=overwrite_a)
@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """Numba overload of ``_lu_factor``: getrf plus conversion to 0-based pivots (scipy convention)."""
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")
    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
        IPIV -= 1  # LAPACK uses 1-based indexing, convert to 0-based
        # Any nonzero INFO (bad argument or exactly-singular U) is an error.
        if INFO != 0:
            raise np.linalg.LinAlgError("LU decomposition failed")
        return A_copy, IPIV
    return impl
...@@ -11,13 +11,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,13 +11,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
_trans_char_to_int,
) )
...@@ -72,116 +72,6 @@ def xgecon_impl( ...@@ -72,116 +72,6 @@ def xgecon_impl(
return impl return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Placeholder for LU factorization; used by linalg.solve.

    The body is never meant to execute: it exists only as an overload target, and
    the numba overload below supplies the real implementation.
    # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
    """
    return  # type: ignore
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """Numba overload of ``_getrf``: call LAPACK ?getrf through ctypes.

    Returns the packed LU factors, the raw pivot vector (1-based, Fortran
    convention), and the LAPACK INFO code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)
    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore
        # getrf factorizes in place, so only reuse A when the caller allows it
        # AND A is already Fortran-ordered; otherwise work on an F-ordered copy.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)
        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
        return A_copy, IPIV, int_ptr_to_val(INFO)
    return impl
def _getrs(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
    """
    Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.

    The body is never meant to execute: it exists only as an overload target, and
    the numba overload below supplies the real implementation.
    # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
    """
    return  # type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
    """Numba overload of ``_getrs``: solve op(A) @ X = B from LU factors via LAPACK ?getrs."""
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)
    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)
        B_is_1d = B.ndim == 1
        # getrs overwrites the RHS in place, so only reuse B when the caller
        # allows it AND B is already Fortran-ordered; otherwise copy.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order_even_if_1d(B)
        # LAPACK expects a 2-d RHS; promote vectors to a single column.
        if B_is_1d:
            B_copy = np.expand_dims(B_copy, -1)
        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)  # rebound from int to int* for the FFI call
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        IPIV = _copy_to_fortran_order(IPIV)
        INFO = val_to_int_ptr(0)
        # NOTE(review): LU is passed without a copy — assumes it is
        # Fortran-contiguous (as produced by getrf); confirm for all callers.
        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )
        # Drop the synthetic column axis added for 1-d inputs.
        if B_is_1d:
            B_copy = B_copy[..., 0]
        return B_copy, int_ptr_to_val(INFO)
    return impl
def _solve_gen( def _solve_gen(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
def _getrs(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
    """
    Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.

    The body is never meant to execute: it exists only as an overload target, and
    the numba overload below supplies the real implementation.
    """
    return  # type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
    """Numba overload of ``_getrs``: solve op(A) @ X = B from LU factors via LAPACK ?getrs."""
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)
    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)
        B_is_1d = B.ndim == 1
        # getrs overwrites the RHS in place, so only reuse B when the caller
        # allows it AND B is already Fortran-ordered; otherwise copy.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order_even_if_1d(B)
        # LAPACK expects a 2-d RHS; promote vectors to a single column.
        if B_is_1d:
            B_copy = np.expand_dims(B_copy, -1)
        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)  # rebound from int to int* for the FFI call
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        # NOTE(review): LAPACK getrs expects 1-based pivots; confirm callers
        # pass LAPACK-style IPIV (the LUFactor dispatch emits 0-based pivots).
        IPIV = _copy_to_fortran_order(IPIV)
        INFO = val_to_int_ptr(0)
        # NOTE(review): LU is passed without a copy — assumes it is
        # Fortran-contiguous (as produced by getrf); confirm for all callers.
        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )
        # Drop the synthetic column axis added for 1-d inputs.
        if B_is_1d:
            B_copy = B_copy[..., 0]
        return B_copy, int_ptr_to_val(INFO)
    return impl
def _lu_solve(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
overwrite_b: bool,
check_finite: bool,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return linalg.lu_solve(
lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_lu_solve)
def lu_solve_impl(
    lu_and_piv: tuple[np.ndarray, np.ndarray],
    b: np.ndarray,
    trans: int,
    overwrite_b: bool,
    check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
    """Numba overload of ``_lu_solve``: forward to ?getrs and validate the INFO code."""
    ensure_lapack()
    _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
    _check_scipy_linalg_matrix(b, "lu_solve")
    # NOTE(review): the inner impl takes six separate arguments (lu, piv, ...)
    # while the overloaded _lu_solve takes five, the first being the
    # (lu, piv) tuple — confirm numba's typing accepts this arity mismatch.
    def impl(
        lu: np.ndarray,
        piv: np.ndarray,
        b: np.ndarray,
        trans: int,
        overwrite_b: bool,
        check_finite: bool,
    ) -> np.ndarray:
        n = np.int32(lu.shape[0])
        # NOTE(review): piv is forwarded unmodified to getrs, which expects
        # 1-based LAPACK pivots — verify the pivot base supplied by callers.
        X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
        _solve_check(n, INFO)
        return X
    return impl
...@@ -4,6 +4,13 @@ import numpy as np ...@@ -4,6 +4,13 @@ import numpy as np
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_lu_1,
_lu_2,
_lu_3,
_pivot_to_permutation,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
...@@ -11,9 +18,12 @@ from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric ...@@ -11,9 +18,12 @@ from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
LUFactor,
PivotToPermutations,
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
...@@ -70,6 +80,96 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -70,6 +80,96 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return cholesky return cholesky
@numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
    # Resolve Op configuration once at dispatch time so the njit closure
    # captures only constants.
    inverse = op.inverse
    dtype = node.inputs[0].dtype

    @numba_njit
    def numba_pivot_to_permutation(piv):
        p_inv = _pivot_to_permutation(piv, dtype)
        # argsort of a permutation is its inverse, giving the forward order.
        return p_inv if inverse else np.argsort(p_inv)

    return numba_pivot_to_permutation
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """Numba dispatch for the LU Op: pick the scipy.linalg.lu overload matching the Op's output signature."""
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a

    dtype = node.inputs[0].dtype
    if dtype in complex_dtypes:
        # BUG FIX: the exception was instantiated but never raised, so complex
        # inputs silently fell through to the (unsupported) real-path overloads.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit(inline="always")
    def lu(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        # Dispatch on the output signature: p_indices -> (perm, L, U),
        # permute_l -> (PL, U), otherwise -> (P, L, U).
        if p_indices:
            res = _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        elif permute_l:
            res = _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        else:
            res = _lu_3(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        return res

    return lu
@numba_funcify.register(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
    """Numba dispatch for the LUFactor Op: wrap the ``_lu_factor`` overload with an optional finiteness check."""
    dtype = node.inputs[0].dtype
    check_finite = op.check_finite
    overwrite_a = op.overwrite_a

    if dtype in complex_dtypes:
        # BUG FIX: the exception was instantiated but never raised, so complex
        # inputs silently fell through to the (unsupported) real-path overload.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit
    def lu_factor(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                # BUG FIX: the message previously said "cholesky" (copy-paste
                # from the Cholesky dispatch).
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu_factor"
                )

        LU, piv = _lu_factor(a, overwrite_a)
        return LU, piv

    return lu_factor
@numba_funcify.register(BlockDiagonal) @numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs): def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype dtype = node.outputs[0].dtype
......
...@@ -8,7 +8,14 @@ import scipy ...@@ -8,7 +8,14 @@ import scipy
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import In, config from pytensor import In, config
from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangular from pytensor.tensor.slinalg import (
LU,
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveTriangular,
)
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
...@@ -494,3 +501,222 @@ def test_block_diag(): ...@@ -494,3 +501,222 @@ def test_block_diag():
C_val = np.random.normal(size=(2, 2)).astype(floatX) C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX) D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val]) compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
    """Check the numba PivotToPermutations dispatch against scipy pivots, in both directions."""
    from pytensor.tensor.slinalg import pivot_to_permutation

    rng = np.random.default_rng(123)
    A = rng.normal(size=(5, 5)).astype(floatX)

    perm_pt = pt.vector("p", dtype="int32")
    piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
    f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")

    _, piv = scipy.linalg.lu_factor(A)

    if inverse:
        # Reference inverse permutation: replay the row swaps encoded in piv.
        p = np.arange(len(piv))
        for i in range(len(piv)):
            p[i], p[piv[i]] = p[piv[i]], p[i]

        np.testing.assert_allclose(f(piv), p)
    else:
        # Forward permutation comes straight from scipy's p_indices output.
        p, *_ = scipy.linalg.lu(A, p_indices=True)
        np.testing.assert_allclose(f(piv), p)
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu(permute_l, p_indices, overwrite_a):
    """Compare the numba LU dispatch with the Python mode for every output
    signature, and check the in-place (destroy_map) semantics for F-contiguous,
    C-contiguous and non-contiguous inputs."""
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        lu_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LU)

    destroy_map = op.destroy_map

    # The destroyed-output index depends on the output signature: with
    # permute_l there are only two outputs, so index 0 aliases the input.
    if overwrite_a and permute_l:
        assert destroy_map == {0: [0]}
    elif overwrite_a:
        assert destroy_map == {1: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(overwrite_a):
    """Compare the numba LUFactor dispatch with the Python mode, and check the
    in-place (destroy_map) semantics for F-contiguous, C-contiguous and
    non-contiguous inputs."""
    shape = (5, 5)
    rng = np.random.default_rng()

    A = pt.tensor("A", shape=shape, dtype=config.floatX)
    A_val = rng.normal(size=shape).astype(config.floatX)

    LU, piv = pt.linalg.lu_factor(A)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        [LU, piv],
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LUFactor)

    if overwrite_a:
        # Output 1 is expected to alias (destroy) input 0 when overwriting.
        assert op.destroy_map == {1: [0]}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
    "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
    "b_func, b_shape",
    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
    ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
    """Compare the numba lu_solve dispatch with the Python mode for several RHS
    shapes, and check that b is destroyed only when overwrite_b applies and the
    input layout allows it."""
    A = pt.matrix("A", dtype=floatX)
    b = pt.tensor("b", shape=b_shape, dtype=floatX)

    rng = np.random.default_rng(418)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    b_val = rng.normal(size=b_shape).astype(floatX)

    lu_and_piv = pt.linalg.lu_factor(A)
    X = pt.linalg.lu_solve(
        lu_and_piv,
        b,
        b_ndim=len(b_shape),
        trans=trans,
    )

    f, res = compare_numba_and_py(
        [A, In(b, mutable=overwrite_b)],
        X,
        test_inputs=[A_val, b_val],
        inplace=True,
        numba_mode=numba_inplace_mode,
        eval_obj_mode=False,
    )

    # Test with F_contiguous inputs
    A_val_f_contig = np.copy(A_val, order="F")
    b_val_f_contig = np.copy(b_val, order="F")
    res_f_contig = f(A_val_f_contig, b_val_f_contig)
    np.testing.assert_allclose(res_f_contig, res)

    all_equal = (b_val == b_val_f_contig).all()
    # NOTE(review): destruction is only expected on the trans path here —
    # confirm this matches the intended rewrite/in-place behavior of lu_solve.
    should_destroy = overwrite_b and trans

    if should_destroy:
        assert not all_equal
    else:
        assert all_equal

    # Test with C_contiguous inputs
    A_val_c_contig = np.copy(A_val, order="C")
    b_val_c_contig = np.copy(b_val, order="C")
    res_c_contig = f(A_val_c_contig, b_val_c_contig)

    np.testing.assert_allclose(res_c_contig, res)
    np.testing.assert_allclose(A_val_c_contig, A_val)

    # b c_contiguous vectors are also f_contiguous and destroyable
    assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose(
        b_val_c_contig, b_val
    )

    # Test with non-contiguous inputs
    A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
    res_not_contig = f(A_val_not_contig, b_val_not_contig)
    np.testing.assert_allclose(res_not_contig, res)
    np.testing.assert_allclose(A_val_not_contig, A_val)

    # Can never destroy non-contiguous inputs
    np.testing.assert_allclose(b_val_not_contig, b_val)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论