提交 672a4829 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not raise in linalg Ops

上级 b2d8bc24
...@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def jax_funcify_SolveTriangular(op, **kwargs): def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b): def solve_triangular(A, b):
return jax.scipy.linalg.solve_triangular( return jax.scipy.linalg.solve_triangular(
...@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs): ...@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower=lower, lower=lower,
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here. trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite, check_finite=False,
) )
return solve_triangular return solve_triangular
...@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs): ...@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def jax_funcify_LU(op, **kwargs): def jax_funcify_LU(op, **kwargs):
permute_l = op.permute_l permute_l = op.permute_l
p_indices = op.p_indices p_indices = op.p_indices
check_finite = op.check_finite
if p_indices: if p_indices:
raise ValueError("JAX does not support the p_indices argument") raise ValueError("JAX does not support the p_indices argument")
def lu(*inputs): def lu(*inputs):
return jax.scipy.linalg.lu( return jax.scipy.linalg.lu(*inputs, permute_l=permute_l, check_finite=False)
*inputs, permute_l=permute_l, check_finite=check_finite
)
return lu return lu
@jax_funcify.register(LUFactor) @jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs): def jax_funcify_LUFactor(op, **kwargs):
check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
def lu_factor(a): def lu_factor(a):
return jax.scipy.linalg.lu_factor( return jax.scipy.linalg.lu_factor(
a, check_finite=check_finite, overwrite_a=overwrite_a a, check_finite=False, overwrite_a=overwrite_a
) )
return lu_factor return lu_factor
...@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs): ...@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register(CholeskySolve) @jax_funcify.register(CholeskySolve)
def jax_funcify_ChoSolve(op, **kwargs): def jax_funcify_ChoSolve(op, **kwargs):
lower = op.lower lower = op.lower
check_finite = op.check_finite
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
def cho_solve(c, b): def cho_solve(c, b):
return jax.scipy.linalg.cho_solve( return jax.scipy.linalg.cho_solve(
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b (c, lower), b, check_finite=False, overwrite_b=overwrite_b
) )
return cho_solve return cho_solve
......
...@@ -263,122 +263,6 @@ class _LAPACK: ...@@ -263,122 +263,6 @@ class _LAPACK:
return potrs return potrs
@classmethod
def numba_xlange(cls, dtype) -> CPUDispatcher:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
float_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)
unique_func_name = f"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def get_lange_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lange")
return ptr
lange_function_type = types.FunctionType(
float_type(
nb_i32p, # NORM
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # WORK
)
)
@numba_basic.numba_njit
def lange(NORM, M, N, A, LDA, WORK):
fn = _call_cached_ptr(
get_ptr_func=get_lange_pointer,
func_type_ref=lange_function_type,
unique_func_name_lit=unique_func_name,
)
return fn(NORM, M, N, A, LDA, WORK)
return lange
@classmethod
def numba_xlamch(cls, dtype) -> CPUDispatcher:
"""
Determine machine precision for floating point arithmetic.
"""
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
unique_func_name = f"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def get_lamch_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lamch")
return ptr
lamch_function_type = types.FunctionType(
float_type( # Return type
nb_i32p, # CMACH
)
)
@numba_basic.numba_njit
def lamch(CMACH):
fn = _call_cached_ptr(
get_ptr_func=get_lamch_pointer,
func_type_ref=lamch_function_type,
unique_func_name_lit=unique_func_name,
)
res = fn(CMACH)
return res
return lamch
@classmethod
def numba_xgecon(cls, dtype) -> CPUDispatcher:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def get_gecon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gecon")
return ptr
gecon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gecon_pointer,
func_type_ref=gecon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return gecon
@classmethod @classmethod
def numba_xgetrf(cls, dtype) -> CPUDispatcher: def numba_xgetrf(cls, dtype) -> CPUDispatcher:
""" """
...@@ -506,91 +390,6 @@ class _LAPACK: ...@@ -506,91 +390,6 @@ class _LAPACK:
return sysv return sysv
@classmethod
def numba_xsycon(cls, dtype) -> CPUDispatcher:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def get_sycon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sycon")
return ptr
sycon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sycon_pointer,
func_type_ref=sycon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return sycon
@classmethod
def numba_xpocon(cls, dtype) -> CPUDispatcher:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def get_pocon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "pocon")
return ptr
pocon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_pocon_pointer,
func_type_ref=pocon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return pocon
@classmethod @classmethod
def numba_xposv(cls, dtype) -> CPUDispatcher: def numba_xposv(cls, dtype) -> CPUDispatcher:
""" """
......
...@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): def _cholesky(a, lower=False, overwrite_a=False):
return ( return linalg.cholesky(a, lower=lower, overwrite_a=overwrite_a, check_finite=False)
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
@overload(_cholesky) @overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): def cholesky_impl(A, lower=0, overwrite_a=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype dtype = A.dtype
numba_potrf = _LAPACK().numba_xpotrf(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=False, overwrite_a=False, check_finite=True): def impl(A, lower=False, overwrite_a=False):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if A.shape[-2] != _N: if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square") raise linalg.LinAlgError("Last 2 dimensions of A must be square")
...@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO, INFO,
) )
if int_ptr_to_val(INFO) != 0:
A_copy = np.full_like(A_copy, np.nan)
return A_copy
if lower: if lower:
for j in range(1, _N): for j in range(1, _N):
for i in range(j): for i in range(j):
...@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N): for i in range(j + 1, _N):
A_copy[i, j] = 0.0 A_copy[i, j] = 0.0
info_int = int_ptr_to_val(INFO)
if transposed: if transposed:
return A_copy.T, info_int return A_copy.T
return A_copy, info_int else:
return A_copy
return impl return impl
...@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): ...@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def _lu_1( def _lu_1(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[True], permute_l: Literal[True],
check_finite: bool,
p_indices: Literal[False], p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -52,7 +51,7 @@ def _lu_1( ...@@ -52,7 +51,7 @@ def _lu_1(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -61,7 +60,6 @@ def _lu_1( ...@@ -61,7 +60,6 @@ def _lu_1(
def _lu_2( def _lu_2(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[False], permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[True], p_indices: Literal[True],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
...@@ -74,7 +72,7 @@ def _lu_2( ...@@ -74,7 +72,7 @@ def _lu_2(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -83,7 +81,6 @@ def _lu_2( ...@@ -83,7 +81,6 @@ def _lu_2(
def _lu_3( def _lu_3(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[False], permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[False], p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -96,7 +93,7 @@ def _lu_3( ...@@ -96,7 +93,7 @@ def _lu_3(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -106,11 +103,10 @@ def _lu_3( ...@@ -106,11 +103,10 @@ def _lu_3(
def lu_impl_1( def lu_impl_1(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] [np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]: ]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...@@ -123,7 +119,6 @@ def lu_impl_1( ...@@ -123,7 +119,6 @@ def lu_impl_1(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -137,10 +132,9 @@ def lu_impl_1( ...@@ -137,10 +132,9 @@ def lu_impl_1(
def lu_impl_2( def lu_impl_2(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]: ) -> Callable[[np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...@@ -153,7 +147,6 @@ def lu_impl_2( ...@@ -153,7 +147,6 @@ def lu_impl_2(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
...@@ -169,11 +162,10 @@ def lu_impl_2( ...@@ -169,11 +162,10 @@ def lu_impl_2(
def lu_impl_3( def lu_impl_3(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] [np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]: ]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...@@ -186,7 +178,6 @@ def lu_impl_3( ...@@ -186,7 +178,6 @@ def lu_impl_3(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......
...@@ -79,11 +79,12 @@ def lu_factor_impl( ...@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]: def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a)
IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based
if INFO != 0: if info != 0:
raise np.linalg.LinAlgError("LU decomposition failed") A_copy = np.full_like(A_copy, np.nan)
return A_copy, IPIV return A_copy, IPIV
return impl return impl
...@@ -228,7 +228,6 @@ def _qr_full_pivot( ...@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode: Literal["full", "economic"] = "full", mode: Literal["full", "economic"] = "full",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -243,7 +242,7 @@ def _qr_full_pivot( ...@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -253,7 +252,6 @@ def _qr_full_no_pivot( ...@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode: Literal["full", "economic"] = "full", mode: Literal["full", "economic"] = "full",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -267,7 +265,7 @@ def _qr_full_no_pivot( ...@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -277,7 +275,6 @@ def _qr_r_pivot( ...@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode: Literal["r", "raw"] = "r", mode: Literal["r", "raw"] = "r",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -291,7 +288,7 @@ def _qr_r_pivot( ...@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -301,7 +298,6 @@ def _qr_r_no_pivot( ...@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode: Literal["r", "raw"] = "r", mode: Literal["r", "raw"] = "r",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -315,7 +311,7 @@ def _qr_r_no_pivot( ...@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -325,7 +321,6 @@ def _qr_raw_no_pivot( ...@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode: Literal["raw"] = "raw", mode: Literal["raw"] = "raw",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -339,7 +334,7 @@ def _qr_raw_no_pivot( ...@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -351,7 +346,6 @@ def _qr_raw_pivot( ...@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode: Literal["raw"] = "raw", mode: Literal["raw"] = "raw",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -365,7 +359,7 @@ def _qr_raw_pivot( ...@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -373,9 +367,7 @@ def _qr_raw_pivot( ...@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload(_qr_full_pivot) @overload(_qr_full_pivot)
def qr_full_pivot_impl( def qr_full_pivot_impl(x, mode="full", pivoting=True, overwrite_a=False, lwork=None):
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -395,7 +387,6 @@ def qr_full_pivot_impl( ...@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode="full", mode="full",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -529,7 +520,7 @@ def qr_full_pivot_impl( ...@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload(_qr_full_no_pivot) @overload(_qr_full_no_pivot)
def qr_full_no_pivot_impl( def qr_full_no_pivot_impl(
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None x, mode="full", pivoting=False, overwrite_a=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
...@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl( ...@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode="full", mode="full",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl( ...@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload(_qr_r_pivot) @overload(_qr_r_pivot)
def qr_r_pivot_impl( def qr_r_pivot_impl(x, mode="r", pivoting=True, overwrite_a=False, lwork=None):
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -658,7 +646,6 @@ def qr_r_pivot_impl( ...@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode="r", mode="r",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -720,9 +707,7 @@ def qr_r_pivot_impl( ...@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload(_qr_r_no_pivot) @overload(_qr_r_no_pivot)
def qr_r_no_pivot_impl( def qr_r_no_pivot_impl(x, mode="r", pivoting=False, overwrite_a=False, lwork=None):
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl( ...@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode="r", mode="r",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl( ...@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload(_qr_raw_no_pivot) @overload(_qr_raw_no_pivot)
def qr_raw_no_pivot_impl( def qr_raw_no_pivot_impl(x, mode="raw", pivoting=False, overwrite_a=False, lwork=None):
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl( ...@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode="raw", mode="raw",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl( ...@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload(_qr_raw_pivot) @overload(_qr_raw_pivot)
def qr_raw_pivot_impl( def qr_raw_pivot_impl(x, mode="raw", pivoting=True, overwrite_a=False, lwork=None):
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
...@@ -880,7 +859,6 @@ def qr_raw_pivot_impl( ...@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode="raw", mode="raw",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
......
...@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
def _cho_solve( def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool):
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
""" """
Solve a positive-definite linear system using the Cholesky decomposition. Solve a positive-definite linear system using the Cholesky decomposition.
""" """
return linalg.cho_solve( return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite (C, lower),
b=B,
overwrite_b=overwrite_b,
check_finite=False,
) )
@overload(_cho_solve) @overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): def cho_solve_impl(C, B, lower=False, overwrite_b=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve") _check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
...@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype = C.dtype dtype = C.dtype
numba_potrs = _LAPACK().numba_xpotrs(dtype) numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True): def impl(C, B, lower=False, overwrite_b=False):
_solve_check_input_shapes(C, B) _solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1]) _N = np.int32(C.shape[-1])
...@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO, INFO,
) )
_solve_check(_N, int_ptr_to_val(INFO)) if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d: if B_is_1d:
return B_copy[..., 0] return B_copy[..., 0]
......
...@@ -3,82 +3,24 @@ from collections.abc import Callable ...@@ -3,82 +3,24 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_solve_check,
) )
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
dtype = A.dtype
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
A_NORM = np.array(A_norm, dtype=dtype)
NORM = val_to_int_ptr(ord(norm))
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(4 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(1)
numba_gecon(
NORM,
N,
A_copy.ctypes,
LDA,
A_NORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_gen( def _solve_gen(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects """Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...@@ -89,7 +31,7 @@ def _solve_gen( ...@@ -89,7 +31,7 @@ def _solve_gen(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
assume_a="gen", assume_a="gen",
transposed=transposed, transposed=transposed,
) )
...@@ -102,9 +44,8 @@ def solve_gen_impl( ...@@ -102,9 +44,8 @@ def solve_gen_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -116,7 +57,6 @@ def solve_gen_impl( ...@@ -116,7 +57,6 @@ def solve_gen_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
...@@ -127,20 +67,18 @@ def solve_gen_impl( ...@@ -127,20 +67,18 @@ def solve_gen_impl(
A = A.T A = A.T
transposed = not transposed transposed = not transposed
order = "I" if transposed else "1" LU, IPIV, INFO1 = _getrf(A, overwrite_a=overwrite_a)
norm = _xlange(A, order=order)
N = A.shape[1] X, INFO2 = _getrs(
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) LU=LU,
_solve_check(N, INFO) B=B,
IPIV=IPIV,
X, INFO = _getrs( trans=transposed,
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b overwrite_b=overwrite_b,
) )
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1") if INFO1 != 0 or INFO2 != 0:
_solve_check(N, INFO, True, RCOND) X = np.full_like(X, np.nan)
return X return X
......
...@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
...@@ -107,14 +106,11 @@ def _lu_solve( ...@@ -107,14 +106,11 @@ def _lu_solve(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
): ):
""" """
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor. Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
""" """
return linalg.lu_solve( return linalg.lu_solve(lu_and_piv, b, trans=trans, overwrite_b=overwrite_b)
lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_lu_solve) @overload(_lu_solve)
...@@ -123,8 +119,7 @@ def lu_solve_impl( ...@@ -123,8 +119,7 @@ def lu_solve_impl(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
lu, _piv = lu_and_piv lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve") _check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
...@@ -137,13 +132,11 @@ def lu_solve_impl( ...@@ -137,13 +132,11 @@ def lu_solve_impl(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
) -> np.ndarray: ) -> np.ndarray:
n = np.int32(lu.shape[0]) X, info = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
_solve_check(n, INFO) if info != 0:
X = np.full_like(X, np.nan)
return X return X
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
    A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
    """
    Numba overload of ``_xlange`` built on the LAPACK xLANGE routine.

    xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity
    norm, or the element of largest absolute value of a matrix A. ``order`` is the
    single character passed to the routine as its NORM selector; when it is None the
    character "1" is used.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
    dtype = A.dtype
    numba_lange = _LAPACK().numba_xlange(dtype)

    def impl(A: np.ndarray, order: str | None = None):
        n_rows, n_cols = np.int32(A.shape[-2:])  # type: ignore
        A_fortran = _copy_to_fortran_order(A)

        # Pick the norm selector character; fall back to "1" when none was given.
        if order is not None:
            norm_ptr = val_to_int_ptr(ord(order))
        else:
            norm_ptr = val_to_int_ptr(ord("1"))

        m_ptr = val_to_int_ptr(n_rows)  # type: ignore
        n_ptr = val_to_int_ptr(n_cols)  # type: ignore
        lda_ptr = val_to_int_ptr(n_rows)  # type: ignore

        # Scratch buffer with one entry per row of A.
        work = np.empty(n_rows, dtype=dtype)  # type: ignore

        return numba_lange(
            norm_ptr, m_ptr, n_ptr, A_fortran.ctypes, lda_ptr, work.ctypes
        )

    return impl
...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
...@@ -27,8 +25,6 @@ def _posv( ...@@ -27,8 +25,6 @@ def _posv(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
""" """
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...@@ -43,10 +39,8 @@ def posv_impl( ...@@ -43,10 +39,8 @@ def posv_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], [np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int], tuple[np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
...@@ -62,8 +56,6 @@ def posv_impl( ...@@ -62,8 +56,6 @@ def posv_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
...@@ -115,60 +107,12 @@ def posv_impl( ...@@ -115,60 +107,12 @@ def posv_impl(
return impl return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
    A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_pocon``, backed by the ``numba_xpocon`` LAPACK binding.

    A is checked with ``_check_linalg_matrix`` (ndim=2, dtype=Float) at typing time.
    The returned ``impl`` calls the binding on a Fortran-ordered copy of A together
    with ``anorm`` and returns ``(RCOND, info)``: a length-1 array of A's dtype and
    the integer status written by the routine.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="pocon")
    dtype = A.dtype
    numba_pocon = _LAPACK().numba_xpocon(dtype)

    def impl(A: np.ndarray, anorm: float):
        n = np.int32(A.shape[-1])
        A_fortran = _copy_to_fortran_order(A)

        # UPLO is fixed to "L" regardless of how A was produced.
        # NOTE(review): confirm callers always supply a lower-triangular factor.
        uplo_ptr = val_to_int_ptr(ord("L"))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        anorm_buf = np.array(anorm, dtype=dtype)

        # Output slot and scratch space handed to the routine.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(3 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        numba_pocon(
            uplo_ptr,
            n_ptr,
            A_fortran.ctypes,
            lda_ptr,
            anorm_buf.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_psd( def _solve_psd(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to """Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...@@ -179,7 +123,7 @@ def _solve_psd( ...@@ -179,7 +123,7 @@ def _solve_psd(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
transposed=transposed, transposed=transposed,
assume_a="pos", assume_a="pos",
) )
...@@ -192,9 +136,8 @@ def solve_psd_impl( ...@@ -192,9 +136,8 @@ def solve_psd_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -206,18 +149,14 @@ def solve_psd_impl( ...@@ -206,18 +149,14 @@ def solve_psd_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
C, x, info = _posv( _C, x, info = _posv(A, B, lower, overwrite_a, overwrite_b)
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(C, _xlange(A)) if info != 0:
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) x = np.full_like(x, np.nan)
return x return x
......
...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
...@@ -121,61 +119,12 @@ def sysv_impl( ...@@ -121,61 +119,12 @@ def sysv_impl(
return impl return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
    A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_sycon``, backed by the ``numba_xsycon`` LAPACK binding.

    A is checked with ``_check_linalg_matrix`` (ndim=2, dtype=Float) at typing time;
    ``ipiv`` is passed to the routine as-is. The returned ``impl`` yields
    ``(RCOND, info)``: a length-1 array of A's dtype and the integer status written
    by the routine.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sycon")
    dtype = A.dtype
    numba_sycon = _LAPACK().numba_xsycon(dtype)

    def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
        n = np.int32(A.shape[-1])
        A_fortran = _copy_to_fortran_order(A)

        # UPLO is fixed to "U" regardless of how A was produced.
        # NOTE(review): confirm this matches the triangle used by the factorization.
        uplo_ptr = val_to_int_ptr(ord("U"))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        anorm_buf = np.array(anorm, dtype=dtype)

        # Output slot and scratch space handed to the routine.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(2 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        numba_sycon(
            uplo_ptr,
            n_ptr,
            A_fortran.ctypes,
            lda_ptr,
            ipiv.ctypes,
            anorm_buf.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_symmetric( def _solve_symmetric(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid """Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...@@ -186,7 +135,7 @@ def _solve_symmetric( ...@@ -186,7 +135,7 @@ def _solve_symmetric(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
assume_a="sym", assume_a="sym",
transposed=transposed, transposed=transposed,
) )
...@@ -199,9 +148,8 @@ def solve_symmetric_impl( ...@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -213,16 +161,14 @@ def solve_symmetric_impl( ...@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) _lu, x, _ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I")) if info != 0:
_solve_check(A.shape[-1], info, True, rcond) x = np.full_like(x, np.nan)
return x return x
......
...@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
def _solve_triangular( def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False A, B, trans=0, lower=False, unit_diagonal=False, overwrite_b=False
): ):
""" """
Thin wrapper around scipy.linalg.solve_triangular. Thin wrapper around scipy.linalg.solve_triangular.
...@@ -39,11 +38,12 @@ def _solve_triangular( ...@@ -39,11 +38,12 @@ def _solve_triangular(
lower=lower, lower=lower,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=False,
) )
@overload(_solve_triangular) @overload(_solve_triangular)
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
...@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet" "This function is not expected to work with complex numbers yet"
) )
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def impl(A, B, trans, lower, unit_diagonal, overwrite_b):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
...@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB, LDB,
INFO, INFO,
) )
if int_ptr_to_val(INFO) != 0:
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) B_copy = np.full_like(B_copy, np.nan)
if B_is_1d: if B_is_1d:
return B_copy[..., 0] return B_copy[..., 0]
......
...@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
from pytensor.tensor._linalg.solve.tridiagonal import ( from pytensor.tensor._linalg.solve.tridiagonal import (
...@@ -202,83 +201,12 @@ def gttrs_impl( ...@@ -202,83 +201,12 @@ def gttrs_impl(
return impl return impl
def _gtcon(
dl: ndarray,
d: ndarray,
du: ndarray,
du2: ndarray,
ipiv: ndarray,
anorm: float,
norm: str,
) -> tuple[ndarray, int]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return # type: ignore
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """
    Numba overload of ``_gtcon``, backed by the ``numba_xgtcon`` LAPACK binding.

    At typing time ``dl``, ``d``, ``du`` and ``du2`` are each checked with
    ``_check_linalg_matrix`` (ndim=1, dtype=Float) and against each other with
    ``_check_dtypes_match``; ``ipiv`` is checked with ndim=1, dtype=int32. The binding
    is selected from ``d``'s dtype. The returned ``impl`` yields ``(rcond, info)``:
    a length-1 array of that dtype and the integer status written by the routine.
    """
    ensure_lapack()
    for diagonal in (dl, d, du, du2):
        _check_linalg_matrix(diagonal, ndim=1, dtype=Float, func_name="gtcon")
    _check_dtypes_match((dl, d, du, du2), func_name="gtcon")
    _check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gtcon")

    dtype = d.dtype
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        n = np.int32(d.shape[-1])

        # Output slot and scratch space handed to the routine.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(2 * n, dtype=dtype)
        iwork = np.empty(n, dtype=np.int32)
        info = val_to_int_ptr(0)

        # Scalar arguments are handed over by pointer / as 0-d buffers.
        norm_ptr = val_to_int_ptr(ord(norm))
        n_ptr = val_to_int_ptr(n)
        anorm_buf = np.array(anorm, dtype=dtype)

        numba_gtcon(
            norm_ptr,
            n_ptr,
            dl.ctypes,
            d.ctypes,
            du.ctypes,
            du2.ctypes,
            ipiv.ctypes,
            anorm_buf.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info,
        )

        return rcond, int_ptr_to_val(info)

    return impl
def _solve_tridiagonal( def _solve_tridiagonal(
a: ndarray, a: ndarray,
b: ndarray, b: ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
""" """
...@@ -290,7 +218,7 @@ def _solve_tridiagonal( ...@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
transposed=transposed, transposed=transposed,
assume_a="tridiagonal", assume_a="tridiagonal",
) )
...@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl( ...@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]: ) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl( ...@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> ndarray: ) -> ndarray:
n = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
norm = "1"
if transposed: if transposed:
A = A.T A = A.T
dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1) dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)
anorm = tridiagonal_norm(du, d, dl) dl, d, du, du2, ipiv, info1 = _gttrf(
dl, d, du, du2, IPIV, INFO = _gttrf(
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
) )
_solve_check(n, INFO)
X, INFO = _gttrs( X, info2 = _gttrs(
dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b dl, d, du, du2, ipiv, B, trans=transposed, overwrite_b=overwrite_b
) )
_solve_check(n, INFO)
RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm) if info1 != 0 or info2 != 0:
_solve_check(n, INFO, True, RCOND) X = np.full_like(X, np.nan)
return X return X
...@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
) )
return dl, d, du, du2, ipiv return dl, d, du, du2, ipiv
cache_key = 1 cache_version = 2
return lu_factor_tridiagonal, cache_key return lu_factor_tridiagonal, cache_version
@register_funcify_default_op_cache_key(SolveLUFactorTridiagonal) @register_funcify_default_op_cache_key(SolveLUFactorTridiagonal)
...@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv = ipiv.astype(np.int32) ipiv = ipiv.astype(np.int32)
if cast_b: if cast_b:
b = b.astype(out_dtype) b = b.astype(out_dtype)
x, _ = _gttrs( x, info = _gttrs(
dl, dl,
d, d,
du, du,
...@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
trans=transposed, trans=transposed,
) )
if info != 0:
x = np.full_like(x, np.nan)
return x return x
cache_key = 1 cache_version = 2
return solve_lu_factor_tridiagonal, cache_key return solve_lu_factor_tridiagonal, cache_version
from collections.abc import Callable, Sequence from collections.abc import Sequence
import numba import numba
from numba.core import types from numba.core import types
from numba.core.extending import overload from numba.np.linalg import _copy_to_fortran_order
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
...@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"): ...@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if first_dtype != other_dtype: if first_dtype != other_dtype:
msg = f"{func_name} only supported for matching dtypes, got {dtypes}" msg = f"{func_name} only supported for matching dtypes, got {dtypes}"
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
    """
    Validate the status code of one step of the solution phase.

    Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38

    A negative ``info`` raises ValueError and a positive ``info`` raises LinAlgError.
    When ``lamch`` is true, ``rcond`` is additionally compared against
    ``_xlamch("E")`` and a message is printed if it is smaller. ``n`` is accepted
    but not read by this function.
    """
    if info < 0:
        # TODO: figure out how to do an fstring here
        raise ValueError("LAPACK reported an illegal value in input")
    if info > 0:
        raise LinAlgError("Matrix is singular.")

    if lamch:
        machine_eps = _xlamch("E")
        if rcond < machine_eps:
            # TODO: This should be a warning, but we can't raise warnings in numba mode
            print(  # noqa: T201
                "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
            )
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
    """
    Numba overload of ``_xlamch``: compute the machine precision for a given
    floating point type.

    The floating point type is derived from ``config.floatX`` (via
    ``_get_underlying_float``), not from any argument. Only "float32" and "float64"
    are handled; anything else raises NotImplementedError. The returned ``impl``
    forwards the ``kind`` character to the ``numba_xlamch`` binding.
    """
    from pytensor import config

    ensure_lapack()
    w_type = _get_underlying_float(config.floatX)

    # Map the float name onto the matching numba type, in the same order as before.
    for float_name, numba_type in (
        ("float32", types.float32),
        ("float64", types.float64),
    ):
        if w_type == float_name:
            dtype = numba_type
            break
    else:
        raise NotImplementedError("Unsupported dtype")

    numba_lamch = _LAPACK().numba_xlamch(dtype)

    def impl(kind: str = "E") -> float:
        kind_ptr = val_to_int_ptr(ord(kind))
        return numba_lamch(kind_ptr)  # type: ignore

    return impl
...@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs): ...@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out[..., i] = new_entry out[..., i] = new_entry
return out return out
cache_key = 1 cache_version = 1
return extract_diag, cache_key return extract_diag, cache_version
@register_funcify_default_op_cache_key(Eye) @register_funcify_default_op_cache_key(Eye)
......
...@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so ...@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite, lower): def decompose_A(A, assume_a, lower):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=check_finite) return lu_factor(A)
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A) return tridiagonal_lu_factor(A)
elif assume_a == "pos": elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite) return cholesky(A, lower=lower)
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -36,7 +35,6 @@ def solve_decomposed_system( ...@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
): ):
b_ndim = core_solve_op.b_ndim b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a assume_a = core_solve_op.assume_a
if assume_a == "gen": if assume_a == "gen":
...@@ -45,10 +43,8 @@ def solve_decomposed_system( ...@@ -45,10 +43,8 @@ def solve_decomposed_system(
b, b,
b_ndim=b_ndim, b_ndim=b_ndim,
trans=transposed, trans=transposed,
check_finite=check_finite,
) )
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU solve
return tridiagonal_lu_solve( return tridiagonal_lu_solve(
A_decomp, A_decomp,
b, b,
...@@ -61,7 +57,6 @@ def solve_decomposed_system( ...@@ -61,7 +57,6 @@ def solve_decomposed_system(
(A_decomp, lower), (A_decomp, lower),
b, b,
b_ndim=b_ndim, b_ndim=b_ndim,
check_finite=check_finite,
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps( ...@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
): ):
return None return None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp = False
for client, _ in A_solve_clients_and_transpose:
if client.op.core_op.check_finite:
check_finite_decomp = True
break
lower = node.op.core_op.lower lower = node.op.core_op.lower
A_decomp = decompose_A( A_decomp = decompose_A(A, assume_a=assume_a, lower=lower)
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {} replacements = {}
for client, transposed in A_solve_clients_and_transpose: for client, transposed in A_solve_clients_and_transpose:
......
差异被折叠。
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal
from pytensor.tensor.slinalg import Cholesky, Solve from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor.type import as_xtensor from pytensor.xtensor.type import as_xtensor
...@@ -10,8 +9,7 @@ def cholesky( ...@@ -10,8 +9,7 @@ def cholesky(
x, x,
lower: bool = True, lower: bool = True,
*, *,
check_finite: bool = False, check_finite: bool = True,
on_error: Literal["raise", "nan"] = "raise",
dims: Sequence[str], dims: Sequence[str],
): ):
"""Compute the Cholesky decomposition of an XTensorVariable. """Compute the Cholesky decomposition of an XTensorVariable.
...@@ -22,22 +20,15 @@ def cholesky( ...@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose. The input variable to decompose.
lower : bool, optional lower : bool, optional
Whether to return the lower triangular matrix. Default is True. Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional check_finite : bool
Whether to check that the input is finite. Default is False. Unused by PyTensor. PyTensor will return nan if the operation fails.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str] dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed. The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
""" """
if len(dims) != 2: if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}") raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
core_op = Cholesky( core_op = Cholesky(lower=lower)
lower=lower,
check_finite=check_finite,
on_error=on_error,
)
core_dims = ( core_dims = (
((dims[0], dims[1]),), ((dims[0], dims[1]),),
((dims[0], dims[1]),), ((dims[0], dims[1]),),
...@@ -52,7 +43,7 @@ def solve( ...@@ -52,7 +43,7 @@ def solve(
dims: Sequence[str], dims: Sequence[str],
assume_a="gen", assume_a="gen",
lower: bool = False, lower: bool = False,
check_finite: bool = False, check_finite: bool = True,
): ):
"""Solve a system of linear equations using XTensorVariables. """Solve a system of linear equations using XTensorVariables.
...@@ -75,8 +66,8 @@ def solve( ...@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"]. Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos". Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool, optional check_finite : bool
Whether to check that the input is finite. Default is False. Unused by PyTensor. PyTensor will return nan if the operation fails.
""" """
a, b = as_xtensor(a), as_xtensor(b) a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
...@@ -98,9 +89,7 @@ def solve( ...@@ -98,9 +89,7 @@ def solve(
else: else:
raise ValueError("Solve dims must have length 2 or 3") raise ValueError("Solve dims must have length 2 or 3")
core_op = Solve( core_op = Solve(b_ndim=b_ndim, assume_a=assume_a, lower=lower)
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
)
x_op = XBlockwise( x_op = XBlockwise(
core_op, core_op,
core_dims=(input_core_dims, output_core_dims), core_dims=(input_core_dims, output_core_dims),
......
import re
from typing import Literal from typing import Literal
import numpy as np import numpy as np
...@@ -36,70 +35,6 @@ floatX = config.floatX ...@@ -36,70 +35,6 @@ floatX = config.floatX
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
def test_lamch():
    """Compare the Numba ``_xlamch`` wrapper with SciPy's LAPACK ``lamch``.

    Each of the query codes ``E``, ``S``, ``P``, ``B``, ``R`` and ``M`` is
    passed to both implementations and the returned values must agree.
    """
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.utils import _xlamch

    # `_xlamch` is wrapped in a jitted function so it can be called from Python.
    @numba.njit()
    def xlamch(kind):
        return _xlamch(kind)

    # The dummy array selects the LAPACK routine matching the `floatX` precision.
    lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))

    for query in ("E", "S", "P", "B", "R", "M"):
        np.testing.assert_allclose(xlamch(query), lamch(query))
@pytest.mark.parametrize(
    "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)]
)
def test_xlange(ord_numba, ord_scipy):
    """Check the Numba ``_xlange`` wrapper against ``scipy.linalg.norm``.

    Parameters
    ----------
    ord_numba : str
        Norm code passed to ``_xlange`` (``"F"``, ``"1"`` or ``"I"``).
    ord_scipy : str or int or float
        The ``ord`` value passed to ``scipy.linalg.norm`` for the same norm.
    """
    # xlange is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy import linalg

    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    # `_xlange` is wrapped in a jitted function so it can be called from Python.
    @numba.njit()
    def xlange(x, norm_kind):
        return _xlange(x, norm_kind)

    # Use the module-level seeded generator so the test input is reproducible.
    x = rng.normal(size=(5, 5)).astype(floatX)
    np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy))
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
def test_xgecon(ord_numba, ord_scipy):
    """Check the Numba ``_xgecon`` wrapper against SciPy's LAPACK ``gecon``.

    Parameters
    ----------
    ord_numba : str
        Norm code (``"1"`` or ``"I"``) passed to both the Numba wrappers and
        the SciPy LAPACK functions.
    ord_scipy : int or float
        Unused by the test body; kept so the parametrization mirrors
        ``test_xlange``.
    """
    # gecon is called internally only, so it is exercised through a thin
    # jitted wrapper.
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.solve.general import _xgecon
    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    @numba.njit()
    def numba_gecon(x, norm):
        anorm = _xlange(x, norm)
        cond, info = _xgecon(x, anorm, norm)
        return cond, info

    # Use the module-level seeded generator so the test input is reproducible.
    x = rng.normal(size=(5, 5)).astype(floatX)

    rcond, info = numba_gecon(x, norm=ord_numba)

    # Test against direct call to the underlying LAPACK functions
    # Solution does **not** agree with 1 / np.linalg.cond(x) !
    lapack_lange, lapack_gecon = get_lapack_funcs(("lange", "gecon"), (x,))
    anorm_ref = lapack_lange(ord_numba, x)
    rcond_ref, _ = lapack_gecon(x, anorm_ref, norm=ord_numba)

    assert info == 0
    np.testing.assert_allclose(rcond, rcond_ref)
class TestSolves: class TestSolves:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -323,7 +258,7 @@ class TestSolves: ...@@ -323,7 +258,7 @@ class TestSolves:
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("value", [np.nan, np.inf]) @pytest.mark.parametrize("value", [np.nan, np.inf])
def test_solve_triangular_raises_on_nan_inf(self, value): def test_solve_triangular_does_not_raise_on_nan_inf(self, value):
A = pt.matrix("A") A = pt.matrix("A")
b = pt.matrix("b") b = pt.matrix("b")
...@@ -335,11 +270,8 @@ class TestSolves: ...@@ -335,11 +270,8 @@ class TestSolves:
A_tri = np.linalg.cholesky(A_sym).astype(floatX) A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX) b = np.full((5, 1), value).astype(floatX)
with pytest.raises( # Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
np.linalg.LinAlgError, assert not np.isfinite(f(A_tri, b)).any()
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -567,10 +499,13 @@ class TestDecompositions: ...@@ -567,10 +499,13 @@ class TestDecompositions:
x = pt.tensor(dtype=floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x) x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True) with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, check_finite=True, on_error="raise")
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): with pytest.raises(
np.linalg.LinAlgError, match=r"Matrix is not positive definite"
):
f(test_value) f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"]) @pytest.mark.parametrize("on_error", ["nan", "raise"])
...@@ -578,13 +513,17 @@ class TestDecompositions: ...@@ -578,13 +513,17 @@ class TestDecompositions:
test_value = rng.random(size=(3, 3)).astype(floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
if on_error == "raise":
with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, on_error=on_error)
else:
g = pt.linalg.cholesky(x, on_error=on_error) g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise": if on_error == "raise":
with pytest.raises( with pytest.raises(
np.linalg.LinAlgError, np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite", match=r"Matrix is not positive definite",
): ):
f(test_value) f(test_value)
else: else:
......
...@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): ...@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize(
    "assume_a, counter",
    (
        ("gen", LUOpCounter),
        ("pos", CholeskyOpCounter),
    ),
)
def test_decomposition_reused_preserves_check_finite(assume_a, counter):
    """Check that the decomposition-reuse rewrite preserves ``check_finite``.

    Two solves share the same ``A``: one is built with ``check_finite=True``
    and the other with ``check_finite=False``. After the rewrite there must be
    a single decomposition node feeding two solve nodes, and only non-finite
    values in ``A`` or ``b1`` are expected to raise.

    Parameters
    ----------
    assume_a : str
        ``assume_a`` argument forwarded to ``solve`` (``"gen"`` or ``"pos"``).
    counter
        Helper class used to count vanilla-solve, decomposition and solve
        nodes in the compiled graph.
    """
    rewrite_name = reuse_decomposition_multiple_solves.__name__

    A = tensor("A", shape=(2, 2))
    b1 = tensor("b1", shape=(2,))
    b2 = tensor("b2", shape=(2,))

    x1 = solve(A, b1, assume_a=assume_a, check_finite=True)
    x2 = solve(A, b2, assume_a=assume_a, check_finite=False)
    fn_opt = function(
        [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
    )

    opt_nodes = fn_opt.maker.fgraph.apply_nodes
    assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
    assert counter.count_decomp_nodes(opt_nodes) == 1
    assert counter.count_solve_nodes(opt_nodes) == 2

    A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
    b1_valid = np.array([1, 1], dtype=b1.type.dtype)
    b2_valid = np.array([1, 1], dtype=b2.type.dtype)

    # Finite inputs are fine; with A the identity, each solution equals its
    # right-hand side.
    res1, res2 = fn_opt(A_valid, b1_valid, b2_valid)
    np.testing.assert_allclose(res1, b1_valid)
    np.testing.assert_allclose(res2, b2_valid)

    # b2 belongs to the check_finite=False solve, so nan there should not
    # raise (also fine on most LAPACK implementations?)
    fn_opt(A_valid, b1_valid, b2_valid * np.nan)

    # We should get an error if A or b1 is non finite
    err_msg = (
        "(array must not contain infs or NaNs"
        r"|Non-numeric values \(nan or inf\))"
    )
    with pytest.raises((ValueError, np.linalg.LinAlgError), match=err_msg):
        fn_opt(A_valid, b1_valid * np.nan, b2_valid)
    with pytest.raises((ValueError, np.linalg.LinAlgError), match=err_msg):
        fn_opt(A_valid * np.nan, b1_valid, b2_valid)
...@@ -74,9 +74,6 @@ def test_cholesky(): ...@@ -74,9 +74,6 @@ def test_cholesky():
chol = Cholesky(lower=False)(x) chol = Cholesky(lower=False)(x)
ch_f = function([x], chol) ch_f = function([x], chol)
check_upper_triangular(pd, ch_f) check_upper_triangular(pd, ch_f)
chol = Cholesky(lower=False, on_error="nan")(x)
ch_f = function([x], chol)
check_upper_triangular(pd, ch_f)
def test_cholesky_performance(benchmark): def test_cholesky_performance(benchmark):
...@@ -102,12 +99,15 @@ def test_cholesky_empty(): ...@@ -102,12 +99,15 @@ def test_cholesky_empty():
def test_cholesky_indef(): def test_cholesky_indef():
x = matrix() x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], cholesky(x)) with pytest.warns(FutureWarning):
out = cholesky(x, lower=True, on_error="raise")
chol_f = function([x], out)
with pytest.raises(scipy.linalg.LinAlgError): with pytest.raises(scipy.linalg.LinAlgError):
chol_f(mat) chol_f(mat)
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], cholesky(x)) out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], out)
assert np.all(np.isnan(chol_f(mat))) assert np.all(np.isnan(chol_f(mat)))
...@@ -143,12 +143,16 @@ def test_cholesky_grad(): ...@@ -143,12 +143,16 @@ def test_cholesky_grad():
def test_cholesky_grad_indef(): def test_cholesky_grad_indef():
x = matrix() x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], grad(cholesky(x).sum(), [x])) with pytest.warns(FutureWarning):
with pytest.raises(scipy.linalg.LinAlgError): out = cholesky(x, lower=True, on_error="raise")
chol_f(mat) chol_f = function([x], grad(out.sum(), [x]), mode="FAST_RUN")
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], grad(cholesky(x).sum(), [x])) # original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert np.all(np.isnan(chol_f(mat)))
out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], grad(out.sum(), [x]))
assert np.all(np.isnan(chol_f(mat))) assert np.all(np.isnan(chol_f(mat)))
...@@ -237,7 +241,7 @@ class TestSolveBase: ...@@ -237,7 +241,7 @@ class TestSolveBase:
y = self.SolveTest(b_ndim=2)(A, b) y = self.SolveTest(b_ndim=2)(A, b)
assert ( assert (
y.__repr__() y.__repr__()
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" == "SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
) )
...@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self): def test_repr(self):
assert ( assert (
repr(CholeskySolve(lower=True, b_ndim=1)) repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)" == "CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
) )
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论