提交 672a4829 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not raise in linalg Ops

上级 b2d8bc24
......@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b):
return jax.scipy.linalg.solve_triangular(
......@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower=lower,
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal,
check_finite=check_finite,
check_finite=False,
)
return solve_triangular
......@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def jax_funcify_LU(op, **kwargs):
    """Return a JAX callable implementing the LU Op described by ``op``.

    Reads ``op.permute_l`` and ``op.p_indices``. The returned function forwards
    its inputs to ``jax.scipy.linalg.lu`` with ``check_finite=False``, so the Op
    never asks the backend to validate finiteness (see the change description:
    "Do not raise in linalg Ops").

    Raises
    ------
    ValueError
        If ``op.p_indices`` is truthy; that option is rejected up front.
    """
    permute_l = op.permute_l
    p_indices = op.p_indices

    if p_indices:
        raise ValueError("JAX does not support the p_indices argument")

    def lu(*inputs):
        # Single call path: the stale variant that forwarded op.check_finite is gone.
        return jax.scipy.linalg.lu(*inputs, permute_l=permute_l, check_finite=False)

    return lu
@jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs):
    """Return a JAX callable implementing the LUFactor Op described by ``op``.

    Only ``op.overwrite_a`` is read. The returned function calls
    ``jax.scipy.linalg.lu_factor`` with ``check_finite=False`` so no finiteness
    validation is requested from the backend.
    """
    overwrite_a = op.overwrite_a

    def lu_factor(a):
        # Exactly one argument list: the duplicated line that forwarded
        # op.check_finite made the call syntactically invalid.
        return jax.scipy.linalg.lu_factor(
            a, check_finite=False, overwrite_a=overwrite_a
        )

    return lu_factor
......@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register(CholeskySolve)
def jax_funcify_ChoSolve(op, **kwargs):
    """Return a JAX callable implementing the CholeskySolve Op described by ``op``.

    Reads ``op.lower`` and ``op.overwrite_b``. The returned function takes the
    Cholesky factor ``c`` and right-hand side ``b`` and calls
    ``jax.scipy.linalg.cho_solve`` with ``check_finite=False``.
    """
    lower = op.lower
    overwrite_b = op.overwrite_b

    def cho_solve(c, b):
        # Exactly one argument list: the duplicated line that forwarded
        # op.check_finite made the call syntactically invalid.
        return jax.scipy.linalg.cho_solve(
            (c, lower), b, check_finite=False, overwrite_b=overwrite_b
        )

    return cho_solve
......
......@@ -263,122 +263,6 @@ class _LAPACK:
return potrs
@classmethod
def numba_xlange(cls, dtype) -> CPUDispatcher:
    """
    Build a numba-compiled caller for the LAPACK routine xLANGE.

    xLANGE computes the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of a
    general M-by-N matrix A. scipy.linalg.solve calls it, but no Op in pytensor corresponds to it.

    The returned function takes the raw LAPACK arguments (NORM, M, N, A, LDA, WORK) and returns the routine's result.
    """
    kind = get_blas_kind(dtype)
    unique_func_name = f"scipy.lapack.{kind}lange"
    scalar_type = _get_nb_float_from_dtype(kind, return_pointer=False)
    scalar_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)

    # The routine returns a scalar, so the signature is built from the scalar type.
    lange_signature = scalar_type(
        nb_i32p,  # NORM
        nb_i32p,  # M
        nb_i32p,  # N
        scalar_pointer,  # A
        nb_i32p,  # LDA
        scalar_pointer,  # WORK
    )
    lange_function_type = types.FunctionType(lange_signature)

    @numba_basic.numba_njit
    def get_lange_pointer():
        # The pointer lookup runs in object mode and is typed as intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "lange")
        return ptr

    @numba_basic.numba_njit
    def lange(NORM, M, N, A, LDA, WORK):
        lapack_lange = _call_cached_ptr(
            get_ptr_func=get_lange_pointer,
            func_type_ref=lange_function_type,
            unique_func_name_lit=unique_func_name,
        )
        return lapack_lange(NORM, M, N, A, LDA, WORK)

    return lange
@classmethod
def numba_xlamch(cls, dtype) -> CPUDispatcher:
    """
    Build a numba-compiled caller for the LAPACK routine xLAMCH.

    xLAMCH determines machine precision for floating point arithmetic. The returned function takes the single
    LAPACK argument CMACH and returns the routine's result.
    """
    kind = get_blas_kind(dtype)
    unique_func_name = f"scipy.lapack.{kind}lamch"
    scalar_type = _get_nb_float_from_dtype(kind, return_pointer=False)

    # Scalar return type, one pointer argument (CMACH).
    lamch_signature = scalar_type(nb_i32p)
    lamch_function_type = types.FunctionType(lamch_signature)

    @numba_basic.numba_njit
    def get_lamch_pointer():
        # The pointer lookup runs in object mode and is typed as intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "lamch")
        return ptr

    @numba_basic.numba_njit
    def lamch(CMACH):
        lapack_lamch = _call_cached_ptr(
            get_ptr_func=get_lamch_pointer,
            func_type_ref=lamch_function_type,
            unique_func_name_lit=unique_func_name,
        )
        return lapack_lamch(CMACH)

    return lamch
@classmethod
def numba_xgecon(cls, dtype) -> CPUDispatcher:
    """
    Build a numba-compiled caller for the LAPACK routine xGECON.

    xGECON estimates the condition number of a matrix A from the LU factorization computed by numba_getrf.
    scipy.linalg.solve calls it when assume_a == "gen".

    The returned function takes the raw LAPACK arguments and returns nothing; results are written through the
    RCOND and INFO pointers.
    """
    kind = get_blas_kind(dtype)
    unique_func_name = f"scipy.lapack.{kind}gecon"
    scalar_pointer = _get_nb_float_from_dtype(kind)

    # Void routine: every output is delivered through a pointer argument.
    gecon_signature = types.void(
        nb_i32p,  # NORM
        nb_i32p,  # N
        scalar_pointer,  # A
        nb_i32p,  # LDA
        scalar_pointer,  # ANORM
        scalar_pointer,  # RCOND
        scalar_pointer,  # WORK
        nb_i32p,  # IWORK
        nb_i32p,  # INFO
    )
    gecon_function_type = types.FunctionType(gecon_signature)

    @numba_basic.numba_njit
    def get_gecon_pointer():
        # The pointer lookup runs in object mode and is typed as intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "gecon")
        return ptr

    @numba_basic.numba_njit
    def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
        lapack_gecon = _call_cached_ptr(
            get_ptr_func=get_gecon_pointer,
            func_type_ref=gecon_function_type,
            unique_func_name_lit=unique_func_name,
        )
        lapack_gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)

    return gecon
@classmethod
def numba_xgetrf(cls, dtype) -> CPUDispatcher:
"""
......@@ -506,91 +390,6 @@ class _LAPACK:
return sysv
@classmethod
def numba_xsycon(cls, dtype) -> CPUDispatcher:
    """
    Build a numba-compiled caller for the LAPACK routine xSYCON.

    xSYCON estimates the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL
    factorization computed by xSYTRF.

    The returned function takes the raw LAPACK arguments and returns nothing; results are written through the
    RCOND and INFO pointers.
    """
    kind = get_blas_kind(dtype)
    unique_func_name = f"scipy.lapack.{kind}sycon"
    scalar_pointer = _get_nb_float_from_dtype(kind)

    # Void routine: every output is delivered through a pointer argument.
    sycon_signature = types.void(
        nb_i32p,  # UPLO
        nb_i32p,  # N
        scalar_pointer,  # A
        nb_i32p,  # LDA
        nb_i32p,  # IPIV
        scalar_pointer,  # ANORM
        scalar_pointer,  # RCOND
        scalar_pointer,  # WORK
        nb_i32p,  # IWORK
        nb_i32p,  # INFO
    )
    sycon_function_type = types.FunctionType(sycon_signature)

    @numba_basic.numba_njit
    def get_sycon_pointer():
        # The pointer lookup runs in object mode and is typed as intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "sycon")
        return ptr

    @numba_basic.numba_njit
    def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
        lapack_sycon = _call_cached_ptr(
            get_ptr_func=get_sycon_pointer,
            func_type_ref=sycon_function_type,
            unique_func_name_lit=unique_func_name,
        )
        lapack_sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)

    return sycon
@classmethod
def numba_xpocon(cls, dtype) -> CPUDispatcher:
    """
    Build a numba-compiled caller for the LAPACK routine xPOCON.

    xPOCON estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky
    factorization computed by potrf. scipy.linalg.solve calls it when assume_a == "pos".

    The returned function takes the raw LAPACK arguments and returns nothing; results are written through the
    RCOND and INFO pointers.
    """
    kind = get_blas_kind(dtype)
    unique_func_name = f"scipy.lapack.{kind}pocon"
    scalar_pointer = _get_nb_float_from_dtype(kind)

    # Void routine: every output is delivered through a pointer argument.
    pocon_signature = types.void(
        nb_i32p,  # UPLO
        nb_i32p,  # N
        scalar_pointer,  # A
        nb_i32p,  # LDA
        scalar_pointer,  # ANORM
        scalar_pointer,  # RCOND
        scalar_pointer,  # WORK
        nb_i32p,  # IWORK
        nb_i32p,  # INFO
    )
    pocon_function_type = types.FunctionType(pocon_signature)

    @numba_basic.numba_njit
    def get_pocon_pointer():
        # The pointer lookup runs in object mode and is typed as intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "pocon")
        return ptr

    @numba_basic.numba_njit
    def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
        lapack_pocon = _call_cached_ptr(
            get_ptr_func=get_pocon_pointer,
            func_type_ref=pocon_function_type,
            unique_func_name_lit=unique_func_name,
        )
        lapack_pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)

    return pocon
@classmethod
def numba_xposv(cls, dtype) -> CPUDispatcher:
"""
......
......@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return (
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
def _cholesky(a, lower=False, overwrite_a=False):
return linalg.cholesky(a, lower=lower, overwrite_a=overwrite_a, check_finite=False)
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
def cholesky_impl(A, lower=0, overwrite_a=False):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=False, overwrite_a=False, check_finite=True):
def impl(A, lower=False, overwrite_a=False):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
......@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO,
)
if int_ptr_to_val(INFO) != 0:
A_copy = np.full_like(A_copy, np.nan)
return A_copy
if lower:
for j in range(1, _N):
for i in range(j):
......@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
info_int = int_ptr_to_val(INFO)
if transposed:
return A_copy.T, info_int
return A_copy, info_int
return A_copy.T
else:
return A_copy
return impl
......@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def _lu_1(
a: np.ndarray,
permute_l: Literal[True],
check_finite: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......@@ -52,7 +51,7 @@ def _lu_1(
return linalg.lu( # type: ignore[no-any-return]
a,
permute_l=permute_l,
check_finite=check_finite,
check_finite=False,
p_indices=p_indices,
overwrite_a=overwrite_a,
)
......@@ -61,7 +60,6 @@ def _lu_1(
def _lu_2(
a: np.ndarray,
permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[True],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
......@@ -74,7 +72,7 @@ def _lu_2(
return linalg.lu( # type: ignore[no-any-return]
a,
permute_l=permute_l,
check_finite=check_finite,
check_finite=False,
p_indices=p_indices,
overwrite_a=overwrite_a,
)
......@@ -83,7 +81,6 @@ def _lu_2(
def _lu_3(
a: np.ndarray,
permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......@@ -96,7 +93,7 @@ def _lu_3(
return linalg.lu( # type: ignore[no-any-return]
a,
permute_l=permute_l,
check_finite=check_finite,
check_finite=False,
p_indices=p_indices,
overwrite_a=overwrite_a,
)
......@@ -106,11 +103,10 @@ def _lu_3(
def lu_impl_1(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
[np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
......@@ -123,7 +119,6 @@ def lu_impl_1(
def impl(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......@@ -137,10 +132,9 @@ def lu_impl_1(
def lu_impl_2(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
) -> Callable[[np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
......@@ -153,7 +147,6 @@ def lu_impl_2(
def impl(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
......@@ -169,11 +162,10 @@ def lu_impl_2(
def lu_impl_3(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
[np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
......@@ -186,7 +178,6 @@ def lu_impl_3(
def impl(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......
......@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a)
IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based
if INFO != 0:
raise np.linalg.LinAlgError("LU decomposition failed")
if info != 0:
A_copy = np.full_like(A_copy, np.nan)
return A_copy, IPIV
return impl
......@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode: Literal["full", "economic"] = "full",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode: Literal["full", "economic"] = "full",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode: Literal["r", "raw"] = "r",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode: Literal["r", "raw"] = "r",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode: Literal["raw"] = "raw",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode: Literal["raw"] = "raw",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
):
"""
......@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode=mode,
pivoting=pivoting,
overwrite_a=overwrite_a,
check_finite=check_finite,
check_finite=False,
lwork=lwork,
)
......@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload(_qr_full_pivot)
def qr_full_pivot_impl(
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
def qr_full_pivot_impl(x, mode="full", pivoting=True, overwrite_a=False, lwork=None):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
......@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode="full",
pivoting=True,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload(_qr_full_no_pivot)
def qr_full_no_pivot_impl(
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
x, mode="full", pivoting=False, overwrite_a=False, lwork=None
):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
......@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode="full",
pivoting=False,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload(_qr_r_pivot)
def qr_r_pivot_impl(
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
def qr_r_pivot_impl(x, mode="r", pivoting=True, overwrite_a=False, lwork=None):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
......@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode="r",
pivoting=True,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload(_qr_r_no_pivot)
def qr_r_no_pivot_impl(
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
def qr_r_no_pivot_impl(x, mode="r", pivoting=False, overwrite_a=False, lwork=None):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
......@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode="r",
pivoting=False,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload(_qr_raw_no_pivot)
def qr_raw_no_pivot_impl(
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
def qr_raw_no_pivot_impl(x, mode="raw", pivoting=False, overwrite_a=False, lwork=None):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
......@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode="raw",
pivoting=False,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload(_qr_raw_pivot)
def qr_raw_pivot_impl(
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
def qr_raw_pivot_impl(x, mode="raw", pivoting=True, overwrite_a=False, lwork=None):
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
......@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode="raw",
pivoting=True,
overwrite_a=False,
check_finite=False,
lwork=None,
):
M = np.int32(x.shape[0])
......
......@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
(C, lower),
b=B,
overwrite_b=overwrite_b,
check_finite=False,
)
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
def cho_solve_impl(C, B, lower=False, overwrite_b=False):
ensure_lapack()
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
......@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype = C.dtype
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
def impl(C, B, lower=False, overwrite_b=False):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
......@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO,
)
_solve_check(_N, int_ptr_to_val(INFO))
if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d:
return B_copy[..., 0]
......
......@@ -3,82 +3,24 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_solve_check,
)
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
    A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_xgecon``: run the LAPACK gecon routine on a matrix A.

    The generated implementation returns the one-element RCOND array filled in by the routine, together with the
    integer INFO status it reports.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
    dtype = A.dtype
    numba_gecon = _LAPACK().numba_xgecon(dtype)

    def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
        n = np.int32(A.shape[-1])
        a_fortran = _copy_to_fortran_order(A)

        # Scalar inputs are passed by pointer.
        norm_ptr = val_to_int_ptr(ord(norm))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        anorm = np.array(A_norm, dtype=dtype)

        # Output slot, workspaces and status flag.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(4 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info_ptr = val_to_int_ptr(1)

        numba_gecon(
            norm_ptr,
            n_ptr,
            a_fortran.ctypes,
            lda_ptr,
            anorm.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_gen(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
......@@ -89,7 +31,7 @@ def _solve_gen(
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
check_finite=False,
assume_a="gen",
transposed=transposed,
)
......@@ -102,9 +44,8 @@ def solve_gen_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
......@@ -116,7 +57,6 @@ def solve_gen_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_N = np.int32(A.shape[-1])
......@@ -127,20 +67,18 @@ def solve_gen_impl(
A = A.T
transposed = not transposed
order = "I" if transposed else "1"
norm = _xlange(A, order=order)
N = A.shape[1]
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
_solve_check(N, INFO)
LU, IPIV, INFO1 = _getrf(A, overwrite_a=overwrite_a)
X, INFO = _getrs(
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b
X, INFO2 = _getrs(
LU=LU,
B=B,
IPIV=IPIV,
trans=transposed,
overwrite_b=overwrite_b,
)
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1")
_solve_check(N, INFO, True, RCOND)
if INFO1 != 0 or INFO2 != 0:
X = np.full_like(X, np.nan)
return X
......
......@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
......@@ -107,14 +106,11 @@ def _lu_solve(
b: np.ndarray,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return linalg.lu_solve(
lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite
)
return linalg.lu_solve(lu_and_piv, b, trans=trans, overwrite_b=overwrite_b)
@overload(_lu_solve)
......@@ -123,8 +119,7 @@ def lu_solve_impl(
b: np.ndarray,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]:
ensure_lapack()
lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
......@@ -137,13 +132,11 @@ def lu_solve_impl(
b: np.ndarray,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> np.ndarray:
n = np.int32(lu.shape[0])
X, info = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
_solve_check(n, INFO)
if info != 0:
X = np.full_like(X, np.nan)
return X
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
    A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
    """
    Numba overload of ``_xlange``.

    xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
    largest absolute value of a matrix A. When ``order`` is None the character "1" is passed as NORM.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
    dtype = A.dtype
    numba_lange = _LAPACK().numba_xlange(dtype)

    def impl(A: np.ndarray, order: str | None = None):
        n_rows, n_cols = np.int32(A.shape[-2:])  # type: ignore
        a_fortran = _copy_to_fortran_order(A)

        # Scalar inputs are passed by pointer.
        m_ptr = val_to_int_ptr(n_rows)  # type: ignore
        n_ptr = val_to_int_ptr(n_cols)  # type: ignore
        lda_ptr = val_to_int_ptr(n_rows)  # type: ignore
        norm_ptr = (
            val_to_int_ptr(ord("1"))
            if order is None
            else val_to_int_ptr(ord(order))
        )
        work = np.empty(n_rows, dtype=dtype)  # type: ignore

        return numba_lange(
            norm_ptr, m_ptr, n_ptr, a_fortran.ctypes, lda_ptr, work.ctypes
        )

    return impl
......@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
......@@ -27,8 +25,6 @@ def _posv(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
......@@ -43,10 +39,8 @@ def posv_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
......@@ -62,8 +56,6 @@ def posv_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B)
......@@ -115,60 +107,12 @@ def posv_impl(
return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
    A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_pocon``: run the LAPACK pocon routine on a matrix A.

    UPLO is always passed as "L". The generated implementation returns the one-element RCOND array filled in by
    the routine, together with the integer INFO status it reports.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="pocon")
    dtype = A.dtype
    numba_pocon = _LAPACK().numba_xpocon(dtype)

    def impl(A: np.ndarray, anorm: float):
        n = np.int32(A.shape[-1])
        a_fortran = _copy_to_fortran_order(A)

        # Scalar inputs are passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("L"))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        anorm_arr = np.array(anorm, dtype=dtype)

        # Output slot, workspaces and status flag.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(3 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        numba_pocon(
            uplo_ptr,
            n_ptr,
            a_fortran.ctypes,
            lda_ptr,
            anorm_arr.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
......@@ -179,7 +123,7 @@ def _solve_psd(
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
check_finite=False,
transposed=transposed,
assume_a="pos",
)
......@@ -192,9 +136,8 @@ def solve_psd_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
......@@ -206,18 +149,14 @@ def solve_psd_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
_C, x, info = _posv(A, B, lower, overwrite_a, overwrite_b)
rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
if info != 0:
x = np.full_like(x, np.nan)
return x
......
......@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
......@@ -121,61 +119,12 @@ def sysv_impl(
return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
    A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_sycon``: run the LAPACK sycon routine on a matrix A and its pivot array.

    UPLO is always passed as "U". The generated implementation returns the one-element RCOND array filled in by
    the routine, together with the integer INFO status it reports.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sycon")
    dtype = A.dtype
    numba_sycon = _LAPACK().numba_xsycon(dtype)

    def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
        n = np.int32(A.shape[-1])
        a_fortran = _copy_to_fortran_order(A)

        # Scalar inputs are passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("U"))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        anorm_arr = np.array(anorm, dtype=dtype)

        # Output slot, workspaces and status flag.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(2 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        numba_sycon(
            uplo_ptr,
            n_ptr,
            a_fortran.ctypes,
            lda_ptr,
            ipiv.ctypes,
            anorm_arr.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
......@@ -186,7 +135,7 @@ def _solve_symmetric(
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
check_finite=False,
assume_a="sym",
transposed=transposed,
)
......@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
......@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
_lu, x, _ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond)
if info != 0:
x = np.full_like(x, np.nan)
return x
......
......@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False
A, B, trans=0, lower=False, unit_diagonal=False, overwrite_b=False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
......@@ -39,11 +38,12 @@ def _solve_triangular(
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
check_finite=False,
)
@overload(_solve_triangular)
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
......@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet"
)
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
def impl(A, B, trans, lower, unit_diagonal, overwrite_b):
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
......@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB,
INFO,
)
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d:
return B_copy[..., 0]
......
......@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
from pytensor.tensor._linalg.solve.tridiagonal import (
......@@ -202,83 +201,12 @@ def gttrs_impl(
return impl
def _gtcon(
dl: ndarray,
d: ndarray,
du: ndarray,
du2: ndarray,
ipiv: ndarray,
anorm: float,
norm: str,
) -> tuple[ndarray, int]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return # type: ignore
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """Numba overload of ``_gtcon``.

    Validates the argument types, looks up the LAPACK routine via
    ``_LAPACK().numba_xgtcon`` for the dtype of ``d`` and returns the
    function numba compiles in place of ``_gtcon``.

    The returned function yields ``(rcond, info)``: ``rcond`` is a length-1
    array written by the LAPACK call and ``info`` is the integer status read
    back from the routine's INFO argument.
    """
    ensure_lapack()

    # The four diagonals must all be 1-d float arrays with a common dtype.
    for diagonal in (dl, d, du, du2):
        _check_linalg_matrix(diagonal, ndim=1, dtype=Float, func_name="gtcon")
    _check_dtypes_match((dl, d, du, du2), func_name="gtcon")
    _check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gtcon")

    dtype = d.dtype
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        n = np.int32(d.shape[-1])

        # Scalars are handed to the routine through pointers / 0-d arrays.
        NORM = val_to_int_ptr(ord(norm))
        N = val_to_int_ptr(n)
        ANORM = np.array(anorm, dtype=dtype)

        # Output slot and scratch buffers sized from the main diagonal length.
        RCOND = np.empty(1, dtype=dtype)
        WORK = np.empty(2 * n, dtype=dtype)
        IWORK = np.empty(n, dtype=np.int32)
        INFO = val_to_int_ptr(0)

        numba_gtcon(
            NORM,
            N,
            dl.ctypes,
            d.ctypes,
            du.ctypes,
            du2.ctypes,
            ipiv.ctypes,
            ANORM.ctypes,
            RCOND.ctypes,
            WORK.ctypes,
            IWORK.ctypes,
            INFO,
        )

        return RCOND, int_ptr_to_val(INFO)

    return impl
def _solve_tridiagonal(
a: ndarray,
b: ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""
......@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
check_finite=False,
transposed=transposed,
assume_a="tridiagonal",
)
......@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
......@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> ndarray:
n = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
norm = "1"
if transposed:
A = A.T
dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)
anorm = tridiagonal_norm(du, d, dl)
dl, d, du, du2, IPIV, INFO = _gttrf(
dl, d, du, du2, ipiv, info1 = _gttrf(
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
)
_solve_check(n, INFO)
X, INFO = _gttrs(
dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b
X, info2 = _gttrs(
dl, d, du, du2, ipiv, B, trans=transposed, overwrite_b=overwrite_b
)
_solve_check(n, INFO)
RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
_solve_check(n, INFO, True, RCOND)
if info1 != 0 or info2 != 0:
X = np.full_like(X, np.nan)
return X
......@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
return dl, d, du, du2, ipiv
cache_key = 1
return lu_factor_tridiagonal, cache_key
cache_version = 2
return lu_factor_tridiagonal, cache_version
@register_funcify_default_op_cache_key(SolveLUFactorTridiagonal)
......@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv = ipiv.astype(np.int32)
if cast_b:
b = b.astype(out_dtype)
x, _ = _gttrs(
x, info = _gttrs(
dl,
d,
du,
......@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b=overwrite_b,
trans=transposed,
)
if info != 0:
x = np.full_like(x, np.nan)
return x
cache_key = 1
return solve_lu_factor_tridiagonal, cache_key
cache_version = 2
return solve_lu_factor_tridiagonal, cache_version
from collections.abc import Callable, Sequence
from collections.abc import Sequence
import numba
from numba.core import types
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from numba.np.linalg import _copy_to_fortran_order
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
@numba_basic.numba_njit(inline="always")
......@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if first_dtype != other_dtype:
msg = f"{func_name} only supported for matching dtypes, got {dtypes}"
raise numba.TypingError(msg, highlighting=False)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
    """
    Validate a LAPACK status code and, optionally, the conditioning estimate.

    Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38

    Raises ``ValueError`` when ``info`` is negative and ``LinAlgError`` when it
    is positive. When ``lamch`` is true, ``rcond`` is compared against
    ``_xlamch("E")`` and a message is printed if it is smaller. ``n`` is
    accepted but not used by the body.
    """
    if info < 0:
        # TODO: figure out how to do an fstring here
        raise ValueError("LAPACK reported an illegal value in input")
    if info > 0:
        raise LinAlgError("Matrix is singular.")

    if lamch:
        eps = _xlamch("E")
        if rcond < eps:
            # TODO: This should be a warning, but we can't raise warnings in numba mode
            print(  # noqa: T201
                "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
            )
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
    """
    Numba overload of ``_xlamch``: query machine-precision parameters.

    The floating point type is taken from ``pytensor.config.floatX`` (not from
    an argument); only float32 and float64 are accepted, anything else raises
    ``NotImplementedError``. The returned function forwards the character code
    of ``kind`` to the routine from ``_LAPACK().numba_xlamch``.
    """
    from pytensor import config

    ensure_lapack()
    float_name = _get_underlying_float(config.floatX)

    # Map the float name to its numba type; fall through to an error otherwise.
    for known_name, numba_type in (
        ("float32", types.float32),
        ("float64", types.float64),
    ):
        if float_name == known_name:
            break
    else:
        raise NotImplementedError("Unsupported dtype")

    numba_lamch = _LAPACK().numba_xlamch(numba_type)

    def impl(kind: str = "E") -> float:
        return numba_lamch(val_to_int_ptr(ord(kind)))  # type: ignore

    return impl
......@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out[..., i] = new_entry
return out
cache_key = 1
return extract_diag, cache_key
cache_version = 1
return extract_diag, cache_version
@register_funcify_default_op_cache_key(Eye)
......
......@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite, lower):
def decompose_A(A, assume_a, lower):
if assume_a == "gen":
return lu_factor(A, check_finite=check_finite)
return lu_factor(A)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A)
elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite)
return cholesky(A, lower=lower)
else:
raise NotImplementedError
......@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a
if assume_a == "gen":
......@@ -45,10 +43,8 @@ def solve_decomposed_system(
b,
b_ndim=b_ndim,
trans=transposed,
check_finite=check_finite,
)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU solve
return tridiagonal_lu_solve(
A_decomp,
b,
......@@ -61,7 +57,6 @@ def solve_decomposed_system(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else:
raise NotImplementedError
......@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
):
return None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp = False
for client, _ in A_solve_clients_and_transpose:
if client.op.core_op.check_finite:
check_finite_decomp = True
break
lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
A_decomp = decompose_A(A, assume_a=assume_a, lower=lower)
replacements = {}
for client, transposed in A_solve_clients_and_transpose:
......
差异被折叠。
from collections.abc import Sequence
from typing import Literal
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor.type import as_xtensor
......@@ -10,8 +9,7 @@ def cholesky(
x,
lower: bool = True,
*,
check_finite: bool = False,
on_error: Literal["raise", "nan"] = "raise",
check_finite: bool = True,
dims: Sequence[str],
):
"""Compute the Cholesky decomposition of an XTensorVariable.
......@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose.
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails.
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
core_op = Cholesky(
lower=lower,
check_finite=check_finite,
on_error=on_error,
)
core_op = Cholesky(lower=lower)
core_dims = (
((dims[0], dims[1]),),
((dims[0], dims[1]),),
......@@ -52,7 +43,7 @@ def solve(
dims: Sequence[str],
assume_a="gen",
lower: bool = False,
check_finite: bool = False,
check_finite: bool = True,
):
"""Solve a system of linear equations using XTensorVariables.
......@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails.
"""
a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
......@@ -98,9 +89,7 @@ def solve(
else:
raise ValueError("Solve dims must have length 2 or 3")
core_op = Solve(
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
)
core_op = Solve(b_ndim=b_ndim, assume_a=assume_a, lower=lower)
x_op = XBlockwise(
core_op,
core_dims=(input_core_dims, output_core_dims),
......
import re
from typing import Literal
import numpy as np
......@@ -36,70 +35,6 @@ floatX = config.floatX
rng = np.random.default_rng(42849)
def test_lamch():
    """Compare the numba ``_xlamch`` with SciPy's LAPACK ``lamch`` for several kind codes."""
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.utils import _xlamch

    @numba.njit()
    def xlamch(kind):
        return _xlamch(kind)

    # A dummy array of dtype floatX selects the matching LAPACK flavour.
    lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))

    for kind in ("E", "S", "P", "B", "R", "M"):
        np.testing.assert_allclose(xlamch(kind), lamch(kind))
@pytest.mark.parametrize(
    "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)]
)
def test_xlange(ord_numba, ord_scipy):
    """Check the numba ``_xlange`` norm against ``scipy.linalg.norm`` on a random 5x5 matrix."""
    # xlange is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy import linalg

    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    @numba.njit()
    def xlange(matrix, norm_kind):
        return _xlange(matrix, norm_kind)

    x = np.random.normal(size=(5, 5)).astype(floatX)
    numba_norm = xlange(x, ord_numba)
    scipy_norm = linalg.norm(x, ord_scipy)
    np.testing.assert_allclose(numba_norm, scipy_norm)
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
def test_xgecon(ord_numba, ord_scipy):
    """Check the numba ``_xgecon`` result against SciPy's LAPACK ``lange`` + ``gecon``."""
    # gecon is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.solve.general import _xgecon
    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    @numba.njit()
    def numba_gecon(x, norm):
        anorm = _xlange(x, norm)
        return _xgecon(x, anorm, norm)

    x = np.random.normal(size=(5, 5)).astype(floatX)
    rcond, info = numba_gecon(x, norm=ord_numba)

    # Test against direct call to the underlying LAPACK functions
    # Solution does **not** agree with 1 / np.linalg.cond(x) !
    lapack_lange, lapack_gecon = get_lapack_funcs(("lange", "gecon"), (x,))
    anorm_ref = lapack_lange(ord_numba, x)
    rcond_ref, _ = lapack_gecon(x, anorm_ref, norm=ord_numba)

    assert info == 0
    np.testing.assert_allclose(rcond, rcond_ref)
class TestSolves:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize(
......@@ -323,7 +258,7 @@ class TestSolves:
np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("value", [np.nan, np.inf])
def test_solve_triangular_raises_on_nan_inf(self, value):
def test_solve_triangular_does_not_raise_on_nan_inf(self, value):
A = pt.matrix("A")
b = pt.matrix("b")
......@@ -335,11 +270,8 @@ class TestSolves:
A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX)
with pytest.raises(
np.linalg.LinAlgError,
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
# Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
assert not np.isfinite(f(A_tri, b)).any()
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
@pytest.mark.parametrize(
......@@ -567,10 +499,13 @@ class TestDecompositions:
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, check_finite=True, on_error="raise")
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
with pytest.raises(
np.linalg.LinAlgError, match=r"Matrix is not positive definite"
):
f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
......@@ -578,13 +513,17 @@ class TestDecompositions:
test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
if on_error == "raise":
with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, on_error=on_error)
else:
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite",
match=r"Matrix is not positive definite",
):
f(test_value)
else:
......
......@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize(
    "assume_a, counter",
    (
        ("gen", LUOpCounter),
        ("pos", CholeskyOpCounter),
    ),
)
def test_decomposition_reused_preserves_check_finite(assume_a, counter):
    """Two solves sharing ``A`` (one with check_finite=True) must keep the finite check after the rewrite."""
    # Check that the LU decomposition rewrite preserves the check_finite flag
    rewrite_name = reuse_decomposition_multiple_solves.__name__

    A = tensor("A", shape=(2, 2))
    b1 = tensor("b1", shape=(2,))
    b2 = tensor("b2", shape=(2,))

    x1 = solve(A, b1, assume_a=assume_a, check_finite=True)
    x2 = solve(A, b2, assume_a=assume_a, check_finite=False)

    mode = get_default_mode().including(rewrite_name)
    fn_opt = function([A, b1, b2], [x1, x2], mode=mode)

    # One shared decomposition feeding two decomposed solves, no plain Solve left.
    opt_nodes = fn_opt.maker.fgraph.apply_nodes
    assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
    assert counter.count_decomp_nodes(opt_nodes) == 1
    assert counter.count_solve_nodes(opt_nodes) == 2

    # We should get an error if A or b1 is non finite
    A_valid = np.eye(2, dtype=A.type.dtype)
    b1_valid = np.ones(2, dtype=b1.type.dtype)
    b2_valid = np.ones(2, dtype=b2.type.dtype)

    assert fn_opt(A_valid, b1_valid, b2_valid)  # Fine
    assert fn_opt(
        A_valid, b1_valid, b2_valid * np.nan
    )  # Should not raise (also fine on most LAPACK implementations?)

    err_msg = (
        "(array must not contain infs or NaNs"
        r"|Non-numeric values \(nan or inf\))"
    )
    failing_inputs = (
        (A_valid, b1_valid * np.nan, b2_valid),
        (A_valid * np.nan, b1_valid, b2_valid),
    )
    for bad_inputs in failing_inputs:
        with pytest.raises((ValueError, np.linalg.LinAlgError), match=err_msg):
            assert fn_opt(*bad_inputs)
......@@ -74,9 +74,6 @@ def test_cholesky():
chol = Cholesky(lower=False)(x)
ch_f = function([x], chol)
check_upper_triangular(pd, ch_f)
chol = Cholesky(lower=False, on_error="nan")(x)
ch_f = function([x], chol)
check_upper_triangular(pd, ch_f)
def test_cholesky_performance(benchmark):
......@@ -102,12 +99,15 @@ def test_cholesky_empty():
def test_cholesky_indef():
x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], cholesky(x))
with pytest.warns(FutureWarning):
out = cholesky(x, lower=True, on_error="raise")
chol_f = function([x], out)
with pytest.raises(scipy.linalg.LinAlgError):
chol_f(mat)
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], cholesky(x))
out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], out)
assert np.all(np.isnan(chol_f(mat)))
......@@ -143,12 +143,16 @@ def test_cholesky_grad():
def test_cholesky_grad_indef():
x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], grad(cholesky(x).sum(), [x]))
with pytest.raises(scipy.linalg.LinAlgError):
chol_f(mat)
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], grad(cholesky(x).sum(), [x]))
with pytest.warns(FutureWarning):
out = cholesky(x, lower=True, on_error="raise")
chol_f = function([x], grad(out.sum(), [x]), mode="FAST_RUN")
# original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert np.all(np.isnan(chol_f(mat)))
out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], grad(out.sum(), [x]))
assert np.all(np.isnan(chol_f(mat)))
......@@ -237,7 +241,7 @@ class TestSolveBase:
y = self.SolveTest(b_ndim=2)(A, b)
assert (
y.__repr__()
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
== "SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
......@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self):
assert (
repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
== "CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
)
def test_infer_shape(self):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论