提交 672a4829 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not raise in linalg Ops

上级 b2d8bc24
...@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def jax_funcify_SolveTriangular(op, **kwargs): def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b): def solve_triangular(A, b):
return jax.scipy.linalg.solve_triangular( return jax.scipy.linalg.solve_triangular(
...@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs): ...@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower=lower, lower=lower,
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here. trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite, check_finite=False,
) )
return solve_triangular return solve_triangular
...@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs): ...@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def jax_funcify_LU(op, **kwargs): def jax_funcify_LU(op, **kwargs):
permute_l = op.permute_l permute_l = op.permute_l
p_indices = op.p_indices p_indices = op.p_indices
check_finite = op.check_finite
if p_indices: if p_indices:
raise ValueError("JAX does not support the p_indices argument") raise ValueError("JAX does not support the p_indices argument")
def lu(*inputs): def lu(*inputs):
return jax.scipy.linalg.lu( return jax.scipy.linalg.lu(*inputs, permute_l=permute_l, check_finite=False)
*inputs, permute_l=permute_l, check_finite=check_finite
)
return lu return lu
@jax_funcify.register(LUFactor) @jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs): def jax_funcify_LUFactor(op, **kwargs):
check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
def lu_factor(a): def lu_factor(a):
return jax.scipy.linalg.lu_factor( return jax.scipy.linalg.lu_factor(
a, check_finite=check_finite, overwrite_a=overwrite_a a, check_finite=False, overwrite_a=overwrite_a
) )
return lu_factor return lu_factor
...@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs): ...@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register(CholeskySolve) @jax_funcify.register(CholeskySolve)
def jax_funcify_ChoSolve(op, **kwargs): def jax_funcify_ChoSolve(op, **kwargs):
lower = op.lower lower = op.lower
check_finite = op.check_finite
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
def cho_solve(c, b): def cho_solve(c, b):
return jax.scipy.linalg.cho_solve( return jax.scipy.linalg.cho_solve(
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b (c, lower), b, check_finite=False, overwrite_b=overwrite_b
) )
return cho_solve return cho_solve
......
...@@ -263,122 +263,6 @@ class _LAPACK: ...@@ -263,122 +263,6 @@ class _LAPACK:
return potrs return potrs
@classmethod
def numba_xlange(cls, dtype) -> CPUDispatcher:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
float_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)
unique_func_name = f"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def get_lange_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lange")
return ptr
lange_function_type = types.FunctionType(
float_type(
nb_i32p, # NORM
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # WORK
)
)
@numba_basic.numba_njit
def lange(NORM, M, N, A, LDA, WORK):
fn = _call_cached_ptr(
get_ptr_func=get_lange_pointer,
func_type_ref=lange_function_type,
unique_func_name_lit=unique_func_name,
)
return fn(NORM, M, N, A, LDA, WORK)
return lange
@classmethod
def numba_xlamch(cls, dtype) -> CPUDispatcher:
"""
Determine machine precision for floating point arithmetic.
"""
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
unique_func_name = f"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def get_lamch_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lamch")
return ptr
lamch_function_type = types.FunctionType(
float_type( # Return type
nb_i32p, # CMACH
)
)
@numba_basic.numba_njit
def lamch(CMACH):
fn = _call_cached_ptr(
get_ptr_func=get_lamch_pointer,
func_type_ref=lamch_function_type,
unique_func_name_lit=unique_func_name,
)
res = fn(CMACH)
return res
return lamch
@classmethod
def numba_xgecon(cls, dtype) -> CPUDispatcher:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def get_gecon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gecon")
return ptr
gecon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gecon_pointer,
func_type_ref=gecon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return gecon
@classmethod @classmethod
def numba_xgetrf(cls, dtype) -> CPUDispatcher: def numba_xgetrf(cls, dtype) -> CPUDispatcher:
""" """
...@@ -506,91 +390,6 @@ class _LAPACK: ...@@ -506,91 +390,6 @@ class _LAPACK:
return sysv return sysv
@classmethod
def numba_xsycon(cls, dtype) -> CPUDispatcher:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def get_sycon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sycon")
return ptr
sycon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sycon_pointer,
func_type_ref=sycon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return sycon
@classmethod
def numba_xpocon(cls, dtype) -> CPUDispatcher:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def get_pocon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "pocon")
return ptr
pocon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_pocon_pointer,
func_type_ref=pocon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return pocon
@classmethod @classmethod
def numba_xposv(cls, dtype) -> CPUDispatcher: def numba_xposv(cls, dtype) -> CPUDispatcher:
""" """
......
...@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): def _cholesky(a, lower=False, overwrite_a=False):
return ( return linalg.cholesky(a, lower=lower, overwrite_a=overwrite_a, check_finite=False)
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
@overload(_cholesky) @overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): def cholesky_impl(A, lower=0, overwrite_a=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype dtype = A.dtype
numba_potrf = _LAPACK().numba_xpotrf(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=False, overwrite_a=False, check_finite=True): def impl(A, lower=False, overwrite_a=False):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if A.shape[-2] != _N: if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square") raise linalg.LinAlgError("Last 2 dimensions of A must be square")
...@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO, INFO,
) )
if int_ptr_to_val(INFO) != 0:
A_copy = np.full_like(A_copy, np.nan)
return A_copy
if lower: if lower:
for j in range(1, _N): for j in range(1, _N):
for i in range(j): for i in range(j):
...@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N): for i in range(j + 1, _N):
A_copy[i, j] = 0.0 A_copy[i, j] = 0.0
info_int = int_ptr_to_val(INFO)
if transposed: if transposed:
return A_copy.T, info_int return A_copy.T
return A_copy, info_int else:
return A_copy
return impl return impl
...@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): ...@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def _lu_1( def _lu_1(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[True], permute_l: Literal[True],
check_finite: bool,
p_indices: Literal[False], p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -52,7 +51,7 @@ def _lu_1( ...@@ -52,7 +51,7 @@ def _lu_1(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -61,7 +60,6 @@ def _lu_1( ...@@ -61,7 +60,6 @@ def _lu_1(
def _lu_2( def _lu_2(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[False], permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[True], p_indices: Literal[True],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
...@@ -74,7 +72,7 @@ def _lu_2( ...@@ -74,7 +72,7 @@ def _lu_2(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -83,7 +81,6 @@ def _lu_2( ...@@ -83,7 +81,6 @@ def _lu_2(
def _lu_3( def _lu_3(
a: np.ndarray, a: np.ndarray,
permute_l: Literal[False], permute_l: Literal[False],
check_finite: bool,
p_indices: Literal[False], p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -96,7 +93,7 @@ def _lu_3( ...@@ -96,7 +93,7 @@ def _lu_3(
return linalg.lu( # type: ignore[no-any-return] return linalg.lu( # type: ignore[no-any-return]
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=False,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -106,11 +103,10 @@ def _lu_3( ...@@ -106,11 +103,10 @@ def _lu_3(
def lu_impl_1( def lu_impl_1(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] [np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]: ]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...@@ -123,7 +119,6 @@ def lu_impl_1( ...@@ -123,7 +119,6 @@ def lu_impl_1(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...@@ -137,10 +132,9 @@ def lu_impl_1( ...@@ -137,10 +132,9 @@ def lu_impl_1(
def lu_impl_2( def lu_impl_2(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]: ) -> Callable[[np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...@@ -153,7 +147,6 @@ def lu_impl_2( ...@@ -153,7 +147,6 @@ def lu_impl_2(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
...@@ -169,11 +162,10 @@ def lu_impl_2( ...@@ -169,11 +162,10 @@ def lu_impl_2(
def lu_impl_3( def lu_impl_3(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] [np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]: ]:
""" """
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...@@ -186,7 +178,6 @@ def lu_impl_3( ...@@ -186,7 +178,6 @@ def lu_impl_3(
def impl( def impl(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: bool,
check_finite: bool,
p_indices: bool, p_indices: bool,
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
......
...@@ -79,11 +79,12 @@ def lu_factor_impl( ...@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]: def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a)
IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based IPIV -= 1 # LAPACK uses 1-based indexing, convert to 0-based
if INFO != 0: if info != 0:
raise np.linalg.LinAlgError("LU decomposition failed") A_copy = np.full_like(A_copy, np.nan)
return A_copy, IPIV return A_copy, IPIV
return impl return impl
...@@ -228,7 +228,6 @@ def _qr_full_pivot( ...@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode: Literal["full", "economic"] = "full", mode: Literal["full", "economic"] = "full",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -243,7 +242,7 @@ def _qr_full_pivot( ...@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -253,7 +252,6 @@ def _qr_full_no_pivot( ...@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode: Literal["full", "economic"] = "full", mode: Literal["full", "economic"] = "full",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -267,7 +265,7 @@ def _qr_full_no_pivot( ...@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -277,7 +275,6 @@ def _qr_r_pivot( ...@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode: Literal["r", "raw"] = "r", mode: Literal["r", "raw"] = "r",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -291,7 +288,7 @@ def _qr_r_pivot( ...@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -301,7 +298,6 @@ def _qr_r_no_pivot( ...@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode: Literal["r", "raw"] = "r", mode: Literal["r", "raw"] = "r",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -315,7 +311,7 @@ def _qr_r_no_pivot( ...@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -325,7 +321,6 @@ def _qr_raw_no_pivot( ...@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode: Literal["raw"] = "raw", mode: Literal["raw"] = "raw",
pivoting: Literal[False] = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -339,7 +334,7 @@ def _qr_raw_no_pivot( ...@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -351,7 +346,6 @@ def _qr_raw_pivot( ...@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode: Literal["raw"] = "raw", mode: Literal["raw"] = "raw",
pivoting: Literal[True] = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
): ):
""" """
...@@ -365,7 +359,7 @@ def _qr_raw_pivot( ...@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite, check_finite=False,
lwork=lwork, lwork=lwork,
) )
...@@ -373,9 +367,7 @@ def _qr_raw_pivot( ...@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload(_qr_full_pivot) @overload(_qr_full_pivot)
def qr_full_pivot_impl( def qr_full_pivot_impl(x, mode="full", pivoting=True, overwrite_a=False, lwork=None):
x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -395,7 +387,6 @@ def qr_full_pivot_impl( ...@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode="full", mode="full",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -529,7 +520,7 @@ def qr_full_pivot_impl( ...@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload(_qr_full_no_pivot) @overload(_qr_full_no_pivot)
def qr_full_no_pivot_impl( def qr_full_no_pivot_impl(
x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None x, mode="full", pivoting=False, overwrite_a=False, lwork=None
): ):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
...@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl( ...@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode="full", mode="full",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl( ...@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload(_qr_r_pivot) @overload(_qr_r_pivot)
def qr_r_pivot_impl( def qr_r_pivot_impl(x, mode="r", pivoting=True, overwrite_a=False, lwork=None):
x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -658,7 +646,6 @@ def qr_r_pivot_impl( ...@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode="r", mode="r",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -720,9 +707,7 @@ def qr_r_pivot_impl( ...@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload(_qr_r_no_pivot) @overload(_qr_r_no_pivot)
def qr_r_no_pivot_impl( def qr_r_no_pivot_impl(x, mode="r", pivoting=False, overwrite_a=False, lwork=None):
x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl( ...@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode="r", mode="r",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl( ...@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload(_qr_raw_no_pivot) @overload(_qr_raw_no_pivot)
def qr_raw_no_pivot_impl( def qr_raw_no_pivot_impl(x, mode="raw", pivoting=False, overwrite_a=False, lwork=None):
x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
...@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl( ...@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode="raw", mode="raw",
pivoting=False, pivoting=False,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
...@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl( ...@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload(_qr_raw_pivot) @overload(_qr_raw_pivot)
def qr_raw_pivot_impl( def qr_raw_pivot_impl(x, mode="raw", pivoting=True, overwrite_a=False, lwork=None):
x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
...@@ -880,7 +859,6 @@ def qr_raw_pivot_impl( ...@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode="raw", mode="raw",
pivoting=True, pivoting=True,
overwrite_a=False, overwrite_a=False,
check_finite=False,
lwork=None, lwork=None,
): ):
M = np.int32(x.shape[0]) M = np.int32(x.shape[0])
......
...@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
def _cho_solve( def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool):
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
""" """
Solve a positive-definite linear system using the Cholesky decomposition. Solve a positive-definite linear system using the Cholesky decomposition.
""" """
return linalg.cho_solve( return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite (C, lower),
b=B,
overwrite_b=overwrite_b,
check_finite=False,
) )
@overload(_cho_solve) @overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): def cho_solve_impl(C, B, lower=False, overwrite_b=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve") _check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
...@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype = C.dtype dtype = C.dtype
numba_potrs = _LAPACK().numba_xpotrs(dtype) numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True): def impl(C, B, lower=False, overwrite_b=False):
_solve_check_input_shapes(C, B) _solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1]) _N = np.int32(C.shape[-1])
...@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO, INFO,
) )
_solve_check(_N, int_ptr_to_val(INFO)) if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d: if B_is_1d:
return B_copy[..., 0] return B_copy[..., 0]
......
...@@ -3,82 +3,24 @@ from collections.abc import Callable ...@@ -3,82 +3,24 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_solve_check,
) )
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
dtype = A.dtype
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
A_NORM = np.array(A_norm, dtype=dtype)
NORM = val_to_int_ptr(ord(norm))
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(4 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(1)
numba_gecon(
NORM,
N,
A_copy.ctypes,
LDA,
A_NORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_gen( def _solve_gen(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects """Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...@@ -89,7 +31,7 @@ def _solve_gen( ...@@ -89,7 +31,7 @@ def _solve_gen(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
assume_a="gen", assume_a="gen",
transposed=transposed, transposed=transposed,
) )
...@@ -102,9 +44,8 @@ def solve_gen_impl( ...@@ -102,9 +44,8 @@ def solve_gen_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -116,7 +57,6 @@ def solve_gen_impl( ...@@ -116,7 +57,6 @@ def solve_gen_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
...@@ -127,20 +67,18 @@ def solve_gen_impl( ...@@ -127,20 +67,18 @@ def solve_gen_impl(
A = A.T A = A.T
transposed = not transposed transposed = not transposed
order = "I" if transposed else "1" LU, IPIV, INFO1 = _getrf(A, overwrite_a=overwrite_a)
norm = _xlange(A, order=order)
N = A.shape[1]
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
_solve_check(N, INFO)
X, INFO = _getrs( X, INFO2 = _getrs(
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b LU=LU,
B=B,
IPIV=IPIV,
trans=transposed,
overwrite_b=overwrite_b,
) )
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1") if INFO1 != 0 or INFO2 != 0:
_solve_check(N, INFO, True, RCOND) X = np.full_like(X, np.nan)
return X return X
......
...@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
...@@ -107,14 +106,11 @@ def _lu_solve( ...@@ -107,14 +106,11 @@ def _lu_solve(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
): ):
""" """
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor. Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
""" """
return linalg.lu_solve( return linalg.lu_solve(lu_and_piv, b, trans=trans, overwrite_b=overwrite_b)
lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_lu_solve) @overload(_lu_solve)
...@@ -123,8 +119,7 @@ def lu_solve_impl( ...@@ -123,8 +119,7 @@ def lu_solve_impl(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
lu, _piv = lu_and_piv lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve") _check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
...@@ -137,13 +132,11 @@ def lu_solve_impl( ...@@ -137,13 +132,11 @@ def lu_solve_impl(
b: np.ndarray, b: np.ndarray,
trans: _Trans, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
) -> np.ndarray: ) -> np.ndarray:
n = np.int32(lu.shape[0]) X, info = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b) if info != 0:
X = np.full_like(X, np.nan)
_solve_check(n, INFO)
return X return X
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
    A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
    """
    Numba overload of ``_xlange``, backed by the LAPACK ``xLANGE`` routine.

    xLANGE returns the one norm, the Frobenius norm, the infinity norm, or the element of largest absolute value
    of a matrix A, depending on the norm selector. When ``order`` is None the selector ``"1"`` is used.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")

    float_dtype = A.dtype
    lapack_lange = _LAPACK().numba_xlange(float_dtype)

    def impl(A: np.ndarray, order: str | None = None):
        n_rows, n_cols = np.int32(A.shape[-2:])  # type: ignore
        a_fortran = _copy_to_fortran_order(A)

        # The norm selector is passed to LAPACK as a pointer to its character code.
        norm_ptr = (
            val_to_int_ptr(ord(order))
            if order is not None
            else val_to_int_ptr(ord("1"))
        )
        m_ptr = val_to_int_ptr(n_rows)  # type: ignore
        n_ptr = val_to_int_ptr(n_cols)  # type: ignore
        lda_ptr = val_to_int_ptr(n_rows)  # type: ignore

        # Scratch buffer with one entry per row of A.
        work = np.empty(n_rows, dtype=float_dtype)  # type: ignore

        return lapack_lange(
            norm_ptr, m_ptr, n_ptr, a_fortran.ctypes, lda_ptr, work.ctypes
        )

    return impl
...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
...@@ -27,8 +25,6 @@ def _posv( ...@@ -27,8 +25,6 @@ def _posv(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
""" """
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...@@ -43,10 +39,8 @@ def posv_impl( ...@@ -43,10 +39,8 @@ def posv_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], [np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int], tuple[np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
...@@ -62,8 +56,6 @@ def posv_impl( ...@@ -62,8 +56,6 @@ def posv_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
...@@ -115,60 +107,12 @@ def posv_impl( ...@@ -115,60 +107,12 @@ def posv_impl(
return impl return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
    A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_pocon``, backed by the LAPACK ``xPOCON`` routine.

    The returned implementation yields ``(RCOND, INFO)``: a one-element array in the dtype of ``A`` and the integer
    status written by LAPACK. The triangle selector is fixed to ``"L"``.
    """
    # NOTE(review): UPLO is hard-coded to "L"; presumably A is always a lower Cholesky factor here — confirm
    # against the caller.
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="pocon")

    float_dtype = A.dtype
    lapack_pocon = _LAPACK().numba_xpocon(float_dtype)

    def impl(A: np.ndarray, anorm: float):
        size = np.int32(A.shape[-1])
        a_fortran = _copy_to_fortran_order(A)

        # Output slot and scratch buffers handed to LAPACK.
        rcond = np.empty(1, dtype=float_dtype)
        work = np.empty(3 * size, dtype=float_dtype)
        iwork = np.empty(size, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        # Scalar arguments are passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("L"))
        n_ptr = val_to_int_ptr(size)
        lda_ptr = val_to_int_ptr(size)
        anorm_arr = np.array(anorm, dtype=float_dtype)

        lapack_pocon(
            uplo_ptr,
            n_ptr,
            a_fortran.ctypes,
            lda_ptr,
            anorm_arr.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_psd( def _solve_psd(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to """Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...@@ -179,7 +123,7 @@ def _solve_psd( ...@@ -179,7 +123,7 @@ def _solve_psd(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
transposed=transposed, transposed=transposed,
assume_a="pos", assume_a="pos",
) )
...@@ -192,9 +136,8 @@ def solve_psd_impl( ...@@ -192,9 +136,8 @@ def solve_psd_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -206,18 +149,14 @@ def solve_psd_impl( ...@@ -206,18 +149,14 @@ def solve_psd_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
C, x, info = _posv( _C, x, info = _posv(A, B, lower, overwrite_a, overwrite_b)
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(C, _xlange(A)) if info != 0:
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) x = np.full_like(x, np.nan)
return x return x
......
...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
) )
...@@ -121,61 +119,12 @@ def sysv_impl( ...@@ -121,61 +119,12 @@ def sysv_impl(
return impl return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
    A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_sycon``, backed by the LAPACK ``xSYCON`` routine.

    The returned implementation yields ``(RCOND, INFO)``: a one-element array in the dtype of ``A`` and the integer
    status written by LAPACK. ``ipiv`` is forwarded to LAPACK unchanged and the triangle selector is fixed to
    ``"U"``.
    """
    # NOTE(review): UPLO is hard-coded to "U"; presumably this matches how the factorization passed in as A was
    # produced — confirm against the caller.
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sycon")

    float_dtype = A.dtype
    lapack_sycon = _LAPACK().numba_xsycon(float_dtype)

    def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
        size = np.int32(A.shape[-1])
        a_fortran = _copy_to_fortran_order(A)

        # Output slot and scratch buffers handed to LAPACK.
        rcond = np.empty(1, dtype=float_dtype)
        work = np.empty(2 * size, dtype=float_dtype)
        iwork = np.empty(size, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        # Scalar arguments are passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("U"))
        n_ptr = val_to_int_ptr(size)
        lda_ptr = val_to_int_ptr(size)
        anorm_arr = np.array(anorm, dtype=float_dtype)

        lapack_sycon(
            uplo_ptr,
            n_ptr,
            a_fortran.ctypes,
            lda_ptr,
            ipiv.ctypes,
            anorm_arr.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_symmetric( def _solve_symmetric(
A: np.ndarray, A: np.ndarray,
B: np.ndarray, B: np.ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid """Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...@@ -186,7 +135,7 @@ def _solve_symmetric( ...@@ -186,7 +135,7 @@ def _solve_symmetric(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
assume_a="sym", assume_a="sym",
transposed=transposed, transposed=transposed,
) )
...@@ -199,9 +148,8 @@ def solve_symmetric_impl( ...@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -213,16 +161,14 @@ def solve_symmetric_impl( ...@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) _lu, x, _ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I")) if info != 0:
_solve_check(A.shape[-1], info, True, rcond) x = np.full_like(x, np.nan)
return x return x
......
...@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
def _solve_triangular( def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False A, B, trans=0, lower=False, unit_diagonal=False, overwrite_b=False
): ):
""" """
Thin wrapper around scipy.linalg.solve_triangular. Thin wrapper around scipy.linalg.solve_triangular.
...@@ -39,11 +38,12 @@ def _solve_triangular( ...@@ -39,11 +38,12 @@ def _solve_triangular(
lower=lower, lower=lower,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=False,
) )
@overload(_solve_triangular) @overload(_solve_triangular)
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def solve_triangular_impl(A, B, trans, lower, unit_diagonal, overwrite_b):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
...@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet" "This function is not expected to work with complex numbers yet"
) )
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def impl(A, B, trans, lower, unit_diagonal, overwrite_b):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)): if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
...@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB, LDB,
INFO, INFO,
) )
if int_ptr_to_val(INFO) != 0:
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) B_copy = np.full_like(B_copy, np.nan)
if B_is_1d: if B_is_1d:
return B_copy[..., 0] return B_copy[..., 0]
......
...@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match, _check_dtypes_match,
_check_linalg_matrix, _check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
from pytensor.tensor._linalg.solve.tridiagonal import ( from pytensor.tensor._linalg.solve.tridiagonal import (
...@@ -202,83 +201,12 @@ def gttrs_impl( ...@@ -202,83 +201,12 @@ def gttrs_impl(
return impl return impl
def _gtcon(
dl: ndarray,
d: ndarray,
du: ndarray,
du2: ndarray,
ipiv: ndarray,
anorm: float,
norm: str,
) -> tuple[ndarray, int]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return # type: ignore
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """
    Numba overload of ``_gtcon``, backed by the LAPACK ``xGTCON`` routine.

    At typing time all four diagonal arrays are checked to be 1d floats of matching dtype and ``ipiv`` to be a 1d
    int32 array. The LAPACK routine is selected from the dtype of ``d``. The returned implementation yields
    ``(RCOND, INFO)``: a one-element array in that dtype and the integer status written by LAPACK.
    """
    ensure_lapack()
    for diagonal in (dl, d, du, du2):
        _check_linalg_matrix(diagonal, ndim=1, dtype=Float, func_name="gtcon")
    _check_dtypes_match((dl, d, du, du2), func_name="gtcon")
    _check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gtcon")

    float_dtype = d.dtype
    lapack_gtcon = _LAPACK().numba_xgtcon(float_dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        size = np.int32(d.shape[-1])

        # Output slot and scratch buffers handed to LAPACK.
        rcond = np.empty(1, dtype=float_dtype)
        work = np.empty(2 * size, dtype=float_dtype)
        iwork = np.empty(size, dtype=np.int32)
        info_ptr = val_to_int_ptr(0)

        # Scalar arguments are passed by pointer; the norm selector is passed as its character code.
        norm_ptr = val_to_int_ptr(ord(norm))
        n_ptr = val_to_int_ptr(size)
        anorm_arr = np.array(anorm, dtype=float_dtype)

        lapack_gtcon(
            norm_ptr,
            n_ptr,
            dl.ctypes,
            d.ctypes,
            du.ctypes,
            du2.ctypes,
            ipiv.ctypes,
            anorm_arr.ctypes,
            rcond.ctypes,
            work.ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
def _solve_tridiagonal( def _solve_tridiagonal(
a: ndarray, a: ndarray,
b: ndarray, b: ndarray,
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
): ):
""" """
...@@ -290,7 +218,7 @@ def _solve_tridiagonal( ...@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower=lower, lower=lower,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite, check_finite=False,
transposed=transposed, transposed=transposed,
assume_a="tridiagonal", assume_a="tridiagonal",
) )
...@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl( ...@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]: ) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
...@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl( ...@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower: bool, lower: bool,
overwrite_a: bool, overwrite_a: bool,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool,
transposed: bool, transposed: bool,
) -> ndarray: ) -> ndarray:
n = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
norm = "1"
if transposed: if transposed:
A = A.T A = A.T
dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1) dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)
anorm = tridiagonal_norm(du, d, dl) dl, d, du, du2, ipiv, info1 = _gttrf(
dl, d, du, du2, IPIV, INFO = _gttrf(
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
) )
_solve_check(n, INFO)
X, INFO = _gttrs( X, info2 = _gttrs(
dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b dl, d, du, du2, ipiv, B, trans=transposed, overwrite_b=overwrite_b
) )
_solve_check(n, INFO)
RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm) if info1 != 0 or info2 != 0:
_solve_check(n, INFO, True, RCOND) X = np.full_like(X, np.nan)
return X return X
...@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
) )
return dl, d, du, du2, ipiv return dl, d, du, du2, ipiv
cache_key = 1 cache_version = 2
return lu_factor_tridiagonal, cache_key return lu_factor_tridiagonal, cache_version
@register_funcify_default_op_cache_key(SolveLUFactorTridiagonal) @register_funcify_default_op_cache_key(SolveLUFactorTridiagonal)
...@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv = ipiv.astype(np.int32) ipiv = ipiv.astype(np.int32)
if cast_b: if cast_b:
b = b.astype(out_dtype) b = b.astype(out_dtype)
x, _ = _gttrs( x, info = _gttrs(
dl, dl,
d, d,
du, du,
...@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
trans=transposed, trans=transposed,
) )
if info != 0:
x = np.full_like(x, np.nan)
return x return x
cache_key = 1 cache_version = 2
return solve_lu_factor_tridiagonal, cache_key return solve_lu_factor_tridiagonal, cache_version
from collections.abc import Callable, Sequence from collections.abc import Sequence
import numba import numba
from numba.core import types from numba.core import types
from numba.core.extending import overload from numba.np.linalg import _copy_to_fortran_order
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
...@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"): ...@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if first_dtype != other_dtype: if first_dtype != other_dtype:
msg = f"{func_name} only supported for matching dtypes, got {dtypes}" msg = f"{func_name} only supported for matching dtypes, got {dtypes}"
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
    """
    Validate the LAPACK status code returned by one step of the solution phase.

    Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38

    A negative ``info`` raises ValueError and a positive ``info`` raises LinAlgError. When ``lamch`` is true,
    ``rcond`` is additionally compared against the value of ``_xlamch("E")`` and a message is printed if it is
    smaller. ``n`` is accepted but not used in the body.
    """
    if info < 0:
        # TODO: figure out how to do an fstring here
        raise ValueError("LAPACK reported an illegal value in input")
    elif info > 0:
        raise LinAlgError("Matrix is singular.")

    if lamch:
        eps = _xlamch("E")
        if rcond < eps:
            # TODO: This should be a warning, but we can't raise warnings in numba mode
            print(  # noqa: T201
                "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
            )
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
    """
    Numba overload of ``_xlamch``: machine precision for a given floating point type via LAPACK ``xLAMCH``.

    The floating point type is taken from ``pytensor.config.floatX`` when the overload is typed, not from the
    call arguments. Only float32 and float64 are handled; anything else raises NotImplementedError.
    """
    # Imported inside the function rather than at module level.
    from pytensor import config

    ensure_lapack()

    float_name = _get_underlying_float(config.floatX)
    if float_name == "float32":
        lamch_dtype = types.float32
    elif float_name == "float64":
        lamch_dtype = types.float64
    else:
        raise NotImplementedError("Unsupported dtype")

    lapack_lamch = _LAPACK().numba_xlamch(lamch_dtype)

    def impl(kind: str = "E") -> float:
        # The query selector is passed to LAPACK as a pointer to its character code.
        kind_ptr = val_to_int_ptr(ord(kind))
        return lapack_lamch(kind_ptr)  # type: ignore

    return impl
...@@ -58,8 +58,6 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -58,8 +58,6 @@ def numba_funcify_Cholesky(op, node, **kwargs):
""" """
lower = op.lower lower = op.lower
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
check_finite = op.check_finite
on_error = op.on_error
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
...@@ -77,30 +75,11 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -77,30 +75,11 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res, info = _cholesky(a, lower, overwrite_a, check_finite)
if on_error == "raise":
if info > 0:
raise np.linalg.LinAlgError(
"Input to cholesky is not positive definite"
)
if info < 0:
raise ValueError(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else:
if info != 0:
res = np.full_like(res, np.nan)
return res return _cholesky(a, lower, overwrite_a)
cache_key = 1 cache_version = 2
return cholesky, cache_key return cholesky, cache_version
@register_funcify_default_op_cache_key(PivotToPermutations) @register_funcify_default_op_cache_key(PivotToPermutations)
...@@ -116,8 +95,8 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -116,8 +95,8 @@ def pivot_to_permutation(op, node, **kwargs):
return np.argsort(p_inv) return np.argsort(p_inv)
cache_key = 1 cache_version = 2
return numba_pivot_to_permutation, cache_key return numba_pivot_to_permutation, cache_version
@register_funcify_default_op_cache_key(LU) @register_funcify_default_op_cache_key(LU)
...@@ -131,7 +110,6 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -131,7 +110,6 @@ def numba_funcify_LU(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
permute_l = op.permute_l permute_l = op.permute_l
check_finite = op.check_finite
p_indices = op.p_indices p_indices = op.p_indices
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
...@@ -151,17 +129,11 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -151,17 +129,11 @@ def numba_funcify_LU(op, node, **kwargs):
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to lu"
)
if p_indices: if p_indices:
res = _lu_1( res = _lu_1(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -169,7 +141,6 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -169,7 +141,6 @@ def numba_funcify_LU(op, node, **kwargs):
res = _lu_2( res = _lu_2(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
...@@ -177,15 +148,14 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -177,15 +148,14 @@ def numba_funcify_LU(op, node, **kwargs):
res = _lu_3( res = _lu_3(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
) )
return res return res
cache_key = 1 cache_version = 2
return lu, cache_key return lu, cache_version
@register_funcify_default_op_cache_key(LUFactor) @register_funcify_default_op_cache_key(LUFactor)
...@@ -198,7 +168,6 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -198,7 +168,6 @@ def numba_funcify_LUFactor(op, node, **kwargs):
print("LUFactor requires casting discrete input to float") # noqa: T201 print("LUFactor requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
@numba_basic.numba_njit @numba_basic.numba_njit
...@@ -211,18 +180,13 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -211,18 +180,13 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if discrete_inp: if discrete_inp:
a = a.astype(out_dtype) a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
)
LU, piv = _lu_factor(a, overwrite_a) LU, piv = _lu_factor(a, overwrite_a)
return LU, piv return LU, piv
cache_key = 1 cache_version = 2
return lu_factor, cache_key return lu_factor, cache_version
@register_funcify_default_op_cache_key(BlockDiagonal) @register_funcify_default_op_cache_key(BlockDiagonal)
...@@ -288,8 +252,8 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -288,8 +252,8 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
globals() | {"np": np}, globals() | {"np": np},
) )
cache_key = 1 cache_version = 1
return numba_basic.numba_njit(block_diag), cache_key return numba_basic.numba_njit(block_diag), cache_version
@register_funcify_default_op_cache_key(Solve) @register_funcify_default_op_cache_key(Solve)
...@@ -306,12 +270,9 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -306,12 +270,9 @@ def numba_funcify_Solve(op, node, **kwargs):
if must_cast_B and config.compiler_verbose: if must_cast_B and config.compiler_verbose:
print("Solve requires casting second input `b`") # noqa: T201 print("Solve requires casting second input `b`") # noqa: T201
check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
assume_a = op.assume_a assume_a = op.assume_a
lower = op.lower lower = op.lower
check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = False # TODO: Solve doesnt currently allow the transposed argument transposed = False # TODO: Solve doesnt currently allow the transposed argument
...@@ -344,30 +305,18 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -344,30 +305,18 @@ def numba_funcify_Solve(op, node, **kwargs):
a = a.astype(out_dtype) a = a.astype(out_dtype)
if must_cast_B: if must_cast_B:
b = b.astype(out_dtype) b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve"
)
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res
cache_key = 1 return solve_fn(a, b, lower, overwrite_a, overwrite_b, transposed)
return solve, cache_key
cache_version = 2
return solve, cache_version
@register_funcify_default_op_cache_key(SolveTriangular) @register_funcify_default_op_cache_key(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs): def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
b_ndim = op.b_ndim
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
...@@ -389,37 +338,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -389,37 +338,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
a = a.astype(out_dtype) a = a.astype(out_dtype)
if must_cast_B: if must_cast_B:
b = b.astype(out_dtype) b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): return _solve_triangular(
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res = _solve_triangular(
a, a,
b, b,
trans=0, # transposing is handled explicitly on the graph, so we never use this argument trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower, lower=lower,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
b_ndim=b_ndim,
) )
return res cache_version = 2
return solve_triangular, cache_version
cache_key = 1
return solve_triangular, cache_key
@register_funcify_default_op_cache_key(CholeskySolve) @register_funcify_default_op_cache_key(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs): def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower lower = op.lower
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
check_finite = op.check_finite
c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs) c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
...@@ -439,36 +375,24 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -439,36 +375,24 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
return np.zeros(b.shape, dtype=out_dtype) return np.zeros(b.shape, dtype=out_dtype)
if must_cast_c: if must_cast_c:
c = c.astype(out_dtype) c = c.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if must_cast_b: if must_cast_b:
b = b.astype(out_dtype) b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
return _cho_solve( return _cho_solve(
c, c,
b, b,
lower=lower, lower=lower,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
check_finite=check_finite,
) )
cache_key = 1 cache_version = 2
return cho_solve, cache_key return cho_solve, cache_version
@register_funcify_default_op_cache_key(QR) @register_funcify_default_op_cache_key(QR)
def numba_funcify_QR(op, node, **kwargs): def numba_funcify_QR(op, node, **kwargs):
mode = op.mode mode = op.mode
check_finite = op.check_finite
pivoting = op.pivoting pivoting = op.pivoting
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
...@@ -481,12 +405,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -481,12 +405,6 @@ def numba_funcify_QR(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def qr(a): def qr(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to qr"
)
if integer_input: if integer_input:
a = a.astype(out_dtype) a = a.astype(out_dtype)
...@@ -496,7 +414,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -496,7 +414,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return Q, R, P return Q, R, P
...@@ -506,7 +423,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -506,7 +423,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return Q, R return Q, R
...@@ -516,7 +432,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -516,7 +432,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return R, P return R, P
...@@ -526,7 +441,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -526,7 +441,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return R return R
...@@ -536,7 +450,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -536,7 +450,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return H, tau, R, P return H, tau, R, P
...@@ -546,7 +459,6 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -546,7 +459,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode=mode, mode=mode,
pivoting=pivoting, pivoting=pivoting,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
check_finite=check_finite,
) )
return H, tau, R return H, tau, R
...@@ -555,5 +467,5 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -555,5 +467,5 @@ def numba_funcify_QR(op, node, **kwargs):
f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
) )
cache_key = 1 cache_version = 2
return qr, cache_key return qr, cache_version
...@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs): ...@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out[..., i] = new_entry out[..., i] = new_entry
return out return out
cache_key = 1 cache_version = 1
return extract_diag, cache_key return extract_diag, cache_version
@register_funcify_default_op_cache_key(Eye) @register_funcify_default_op_cache_key(Eye)
......
...@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so ...@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite, lower): def decompose_A(A, assume_a, lower):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=check_finite) return lu_factor(A)
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A) return tridiagonal_lu_factor(A)
elif assume_a == "pos": elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite) return cholesky(A, lower=lower)
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -36,7 +35,6 @@ def solve_decomposed_system( ...@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
): ):
b_ndim = core_solve_op.b_ndim b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a assume_a = core_solve_op.assume_a
if assume_a == "gen": if assume_a == "gen":
...@@ -45,10 +43,8 @@ def solve_decomposed_system( ...@@ -45,10 +43,8 @@ def solve_decomposed_system(
b, b,
b_ndim=b_ndim, b_ndim=b_ndim,
trans=transposed, trans=transposed,
check_finite=check_finite,
) )
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU solve
return tridiagonal_lu_solve( return tridiagonal_lu_solve(
A_decomp, A_decomp,
b, b,
...@@ -61,7 +57,6 @@ def solve_decomposed_system( ...@@ -61,7 +57,6 @@ def solve_decomposed_system(
(A_decomp, lower), (A_decomp, lower),
b, b,
b_ndim=b_ndim, b_ndim=b_ndim,
check_finite=check_finite,
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps( ...@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
): ):
return None return None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp = False
for client, _ in A_solve_clients_and_transpose:
if client.op.core_op.check_finite:
check_finite_decomp = True
break
lower = node.op.core_op.lower lower = node.op.core_op.lower
A_decomp = decompose_A( A_decomp = decompose_A(A, assume_a=assume_a, lower=lower)
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {} replacements = {}
for client, transposed in A_solve_clients_and_transpose: for client, transposed in A_solve_clients_and_transpose:
......
...@@ -6,7 +6,7 @@ from typing import Literal, cast ...@@ -6,7 +6,7 @@ from typing import Literal, cast
import numpy as np import numpy as np
import scipy.linalg as scipy_linalg import scipy.linalg as scipy_linalg
from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs from scipy.linalg import get_lapack_funcs
import pytensor import pytensor
from pytensor import ifelse from pytensor import ifelse
...@@ -14,7 +14,7 @@ from pytensor import tensor as pt ...@@ -14,7 +14,7 @@ from pytensor import tensor as pt
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.raise_op import Assert from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
...@@ -32,22 +32,16 @@ logger = logging.getLogger(__name__) ...@@ -32,22 +32,16 @@ logger = logging.getLogger(__name__)
class Cholesky(Op): class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also # TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "check_finite", "on_error", "overwrite_a") __props__ = ("lower", "overwrite_a")
gufunc_signature = "(m,m)->(m,m)" gufunc_signature = "(m,m)->(m,m)"
def __init__( def __init__(
self, self,
*, *,
lower: bool = True, lower: bool = True,
check_finite: bool = False,
on_error: Literal["raise", "nan"] = "raise",
overwrite_a: bool = False, overwrite_a: bool = False,
): ):
self.lower = lower self.lower = lower
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
self.overwrite_a = overwrite_a self.overwrite_a = overwrite_a
if self.overwrite_a: if self.overwrite_a:
...@@ -77,13 +71,6 @@ class Cholesky(Op): ...@@ -77,13 +71,6 @@ class Cholesky(Op):
out[0] = np.empty_like(x, dtype=potrf.dtype) out[0] = np.empty_like(x, dtype=potrf.dtype)
return return
if self.check_finite and not np.isfinite(x).all():
if self.on_error == "nan":
out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype)
return
else:
raise ValueError("array must not contain infs or NaNs")
# Squareness check # Squareness check
if x.shape[0] != x.shape[1]: if x.shape[0] != x.shape[1]:
raise ValueError( raise ValueError(
...@@ -104,17 +91,8 @@ class Cholesky(Op): ...@@ -104,17 +91,8 @@ class Cholesky(Op):
c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)
if info != 0: if info != 0:
if self.on_error == "nan": c[...] = np.nan
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) out[0] = c
elif info > 0:
raise scipy_linalg.LinAlgError(
f"{info}-th leading minor of the array is not positive definite"
)
elif info < 0:
raise ValueError(
f"LAPACK reported an illegal value in {-info}-th argument "
f'on entry to "POTRF".'
)
else: else:
# Transpose result if input was transposed # Transpose result if input was transposed
out[0] = c.T if c_contiguous_input else c out[0] = c.T if c_contiguous_input else c
...@@ -135,13 +113,6 @@ class Cholesky(Op): ...@@ -135,13 +113,6 @@ class Cholesky(Op):
dz = gradients[0] dz = gradients[0]
chol_x = outputs[0] chol_x = outputs[0]
# Replace the cholesky decomposition with 1 if there are nans
# or solve_upper_triangular will throw a ValueError.
if self.on_error == "nan":
ok = ~ptm.any(ptm.isnan(chol_x))
chol_x = ptb.switch(ok, chol_x, 1)
dz = ptb.switch(ok, dz, 1)
# deal with upper triangular by converting to lower triangular # deal with upper triangular by converting to lower triangular
if not self.lower: if not self.lower:
chol_x = chol_x.T chol_x = chol_x.T
...@@ -165,10 +136,7 @@ class Cholesky(Op): ...@@ -165,10 +136,7 @@ class Cholesky(Op):
else: else:
grad = ptb.triu(s + s.T) - ptb.diag(ptb.diagonal(s)) grad = ptb.triu(s + s.T) - ptb.diag(ptb.diagonal(s))
if self.on_error == "nan": return [grad]
return [ptb.switch(ok, grad, np.nan)]
else:
return [grad]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs: if not allowed_inplace_inputs:
...@@ -182,9 +150,9 @@ def cholesky( ...@@ -182,9 +150,9 @@ def cholesky(
x: "TensorLike", x: "TensorLike",
lower: bool = True, lower: bool = True,
*, *,
check_finite: bool = False, check_finite: bool = True,
overwrite_a: bool = False, overwrite_a: bool = False,
on_error: Literal["raise", "nan"] = "raise", on_error: Literal["raise", "nan"] = "nan",
): ):
""" """
Return a triangular matrix square root of positive semi-definite `x`. Return a triangular matrix square root of positive semi-definite `x`.
...@@ -196,8 +164,8 @@ def cholesky( ...@@ -196,8 +164,8 @@ def cholesky(
x: tensor_like x: tensor_like
lower : bool, default=True lower : bool, default=True
Whether to return the lower or upper cholesky factor Whether to return the lower or upper cholesky factor
check_finite : bool, default=False check_finite : bool
Whether to check that the input matrix contains only finite numbers. Unused by PyTensor. PyTensor will return nan if the operation fails.
overwrite_a: bool, ignored overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky. for consistency with scipy.linalg.cholesky.
...@@ -228,10 +196,19 @@ def cholesky( ...@@ -228,10 +196,19 @@ def cholesky(
assert np.allclose(L_value @ L_value.T, x_value) assert np.allclose(L_value @ L_value.T, x_value)
""" """
res = Blockwise(Cholesky(lower=lower))(x)
return Blockwise( if on_error == "raise":
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) # For back-compatibility
)(x) warnings.warn(
'Cholesky on_raise == "raise" is deprecated. The operation will return nan when in fails. Setting this argument will fail in the future',
FutureWarning,
)
res = CheckAndRaise(np.linalg.LinAlgError, "Matrix is not positive definite")(
res, ~ptm.isnan(res).any()
)
return res
class SolveBase(Op): class SolveBase(Op):
...@@ -239,7 +216,6 @@ class SolveBase(Op): ...@@ -239,7 +216,6 @@ class SolveBase(Op):
__props__: tuple[str, ...] = ( __props__: tuple[str, ...] = (
"lower", "lower",
"check_finite",
"b_ndim", "b_ndim",
"overwrite_a", "overwrite_a",
"overwrite_b", "overwrite_b",
...@@ -249,13 +225,11 @@ class SolveBase(Op): ...@@ -249,13 +225,11 @@ class SolveBase(Op):
self, self,
*, *,
lower=False, lower=False,
check_finite=True,
b_ndim, b_ndim,
overwrite_a=False, overwrite_a=False,
overwrite_b=False, overwrite_b=False,
): ):
self.lower = lower self.lower = lower
self.check_finite = check_finite
assert b_ndim in (1, 2) assert b_ndim in (1, 2)
self.b_ndim = b_ndim self.b_ndim = b_ndim
...@@ -358,7 +332,6 @@ def _default_b_ndim(b, b_ndim): ...@@ -358,7 +332,6 @@ def _default_b_ndim(b, b_ndim):
class CholeskySolve(SolveBase): class CholeskySolve(SolveBase):
__props__ = ( __props__ = (
"lower", "lower",
"check_finite",
"b_ndim", "b_ndim",
"overwrite_b", "overwrite_b",
) )
...@@ -366,7 +339,6 @@ class CholeskySolve(SolveBase): ...@@ -366,7 +339,6 @@ class CholeskySolve(SolveBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
if kwargs.get("overwrite_a", False): if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for CholeskySolve") raise ValueError("overwrite_a is not supported for CholeskySolve")
kwargs.setdefault("lower", True)
super().__init__(**kwargs) super().__init__(**kwargs)
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -387,9 +359,6 @@ class CholeskySolve(SolveBase): ...@@ -387,9 +359,6 @@ class CholeskySolve(SolveBase):
(potrs,) = get_lapack_funcs(("potrs",), (c, b)) (potrs,) = get_lapack_funcs(("potrs",), (c, b))
if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
raise ValueError("array must not contain infs or NaNs")
if c.shape[0] != c.shape[1]: if c.shape[0] != c.shape[1]:
raise ValueError("The factored matrix c is not square.") raise ValueError("The factored matrix c is not square.")
if c.shape[1] != b.shape[0]: if c.shape[1] != b.shape[0]:
...@@ -402,7 +371,7 @@ class CholeskySolve(SolveBase): ...@@ -402,7 +371,7 @@ class CholeskySolve(SolveBase):
x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b) x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
if info != 0: if info != 0:
raise ValueError(f"illegal value in {-info}th argument of internal potrs") x[...] = np.nan
output_storage[0][0] = x output_storage[0][0] = x
...@@ -423,7 +392,6 @@ def cho_solve( ...@@ -423,7 +392,6 @@ def cho_solve(
c_and_lower: tuple[TensorLike, bool], c_and_lower: tuple[TensorLike, bool],
b: TensorLike, b: TensorLike,
*, *,
check_finite: bool = True,
b_ndim: int | None = None, b_ndim: int | None = None,
): ):
"""Solve the linear equations A x = b, given the Cholesky factorization of A. """Solve the linear equations A x = b, given the Cholesky factorization of A.
...@@ -434,33 +402,26 @@ def cho_solve( ...@@ -434,33 +402,26 @@ def cho_solve(
Cholesky factorization of a, as given by cho_factor Cholesky factorization of a, as given by cho_factor
b : TensorLike b : TensorLike
Right-hand side Right-hand side
check_finite : bool, optional check_finite : bool
Whether to check that the input matrices contain only finite numbers. Unused by PyTensor. PyTensor will return nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int b_ndim : int
Whether the core case of b is a vector (1) or matrix (2). Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
A, lower = c_and_lower A, lower = c_and_lower
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise( return Blockwise(CholeskySolve(lower=lower, b_ndim=b_ndim))(A, b)
CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
)(A, b)
class LU(Op): class LU(Op):
"""Decompose a matrix into lower and upper triangular matrices.""" """Decompose a matrix into lower and upper triangular matrices."""
__props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices") __props__ = ("permute_l", "overwrite_a", "p_indices")
def __init__( def __init__(self, *, permute_l=False, overwrite_a=False, p_indices=False):
self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
):
if permute_l and p_indices: if permute_l and p_indices:
raise ValueError("Only one of permute_l and p_indices can be True") raise ValueError("Only one of permute_l and p_indices can be True")
self.permute_l = permute_l self.permute_l = permute_l
self.check_finite = check_finite
self.p_indices = p_indices self.p_indices = p_indices
self.overwrite_a = overwrite_a self.overwrite_a = overwrite_a
...@@ -523,7 +484,6 @@ class LU(Op): ...@@ -523,7 +484,6 @@ class LU(Op):
A, A,
permute_l=self.permute_l, permute_l=self.permute_l,
overwrite_a=self.overwrite_a, overwrite_a=self.overwrite_a,
check_finite=self.check_finite,
p_indices=self.p_indices, p_indices=self.p_indices,
) )
...@@ -563,7 +523,7 @@ class LU(Op): ...@@ -563,7 +523,7 @@ class LU(Op):
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
P_or_indices, L, U = lu( # type: ignore P_or_indices, L, U = lu( # type: ignore
A, permute_l=False, check_finite=self.check_finite, p_indices=False A, permute_l=False, p_indices=False
) )
else: else:
...@@ -621,8 +581,8 @@ def lu( ...@@ -621,8 +581,8 @@ def lu(
permute_l: bool permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular. be returned in this case, and PL will not be lower triangular.
check_finite: bool check_finite : bool
Whether to check that the input matrix contains only finite numbers. Unused by PyTensor. PyTensor will return nan if the operation fails.
p_indices: bool p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself. itself.
...@@ -640,9 +600,7 @@ def lu( ...@@ -640,9 +600,7 @@ def lu(
return cast( return cast(
tuple[TensorVariable, TensorVariable, TensorVariable] tuple[TensorVariable, TensorVariable, TensorVariable]
| tuple[TensorVariable, TensorVariable], | tuple[TensorVariable, TensorVariable],
Blockwise( Blockwise(LU(permute_l=permute_l, p_indices=p_indices))(a),
LU(permute_l=permute_l, p_indices=p_indices, check_finite=check_finite)
)(a),
) )
...@@ -680,12 +638,11 @@ def pivot_to_permutation(p: TensorLike, inverse=False): ...@@ -680,12 +638,11 @@ def pivot_to_permutation(p: TensorLike, inverse=False):
class LUFactor(Op): class LUFactor(Op):
__props__ = ("overwrite_a", "check_finite") __props__ = ("overwrite_a",)
gufunc_signature = "(m,m)->(m,m),(m)" gufunc_signature = "(m,m)->(m,m),(m)"
def __init__(self, *, overwrite_a=False, check_finite=True): def __init__(self, *, overwrite_a=False):
self.overwrite_a = overwrite_a self.overwrite_a = overwrite_a
self.check_finite = check_finite
if self.overwrite_a: if self.overwrite_a:
self.destroy_map = {1: [0]} self.destroy_map = {1: [0]}
...@@ -723,21 +680,10 @@ class LUFactor(Op): ...@@ -723,21 +680,10 @@ class LUFactor(Op):
outputs[1][0] = np.array([], dtype=np.int32) outputs[1][0] = np.array([], dtype=np.int32)
return return
if self.check_finite and not np.isfinite(A).all():
raise ValueError("array must not contain infs or NaNs")
(getrf,) = get_lapack_funcs(("getrf",), (A,)) (getrf,) = get_lapack_funcs(("getrf",), (A,))
LU, p, info = getrf(A, overwrite_a=self.overwrite_a) LU, p, info = getrf(A, overwrite_a=self.overwrite_a)
if info < 0: if info != 0:
raise ValueError( LU[...] = np.nan
f"illegal value in {-info}th argument of internal getrf (lu_factor)"
)
if info > 0:
warnings.warn(
f"Diagonal number {info} is exactly zero. Singular matrix.",
LinAlgWarning,
stacklevel=2,
)
outputs[0][0] = LU outputs[0][0] = LU
outputs[1][0] = p outputs[1][0] = p
...@@ -782,7 +728,7 @@ def lu_factor( ...@@ -782,7 +728,7 @@ def lu_factor(
a: TensorLike a: TensorLike
Matrix to be factorized Matrix to be factorized
check_finite: bool check_finite: bool
Whether to check that the input matrix contains only finite numbers. Unused by PyTensor. PyTensor will return nan if the operation fails.
overwrite_a: bool overwrite_a: bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible. Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
...@@ -796,7 +742,7 @@ def lu_factor( ...@@ -796,7 +742,7 @@ def lu_factor(
return cast( return cast(
tuple[TensorVariable, TensorVariable], tuple[TensorVariable, TensorVariable],
Blockwise(LUFactor(check_finite=check_finite))(a), Blockwise(LUFactor())(a),
) )
...@@ -806,7 +752,6 @@ def _lu_solve( ...@@ -806,7 +752,6 @@ def _lu_solve(
b: TensorLike, b: TensorLike,
trans: bool = False, trans: bool = False,
b_ndim: int | None = None, b_ndim: int | None = None,
check_finite: bool = True,
): ):
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
...@@ -824,7 +769,6 @@ def _lu_solve( ...@@ -824,7 +769,6 @@ def _lu_solve(
unit_diagonal=not trans, unit_diagonal=not trans,
trans=trans, trans=trans,
b_ndim=b_ndim, b_ndim=b_ndim,
check_finite=check_finite,
) )
x = solve_triangular( x = solve_triangular(
...@@ -834,7 +778,6 @@ def _lu_solve( ...@@ -834,7 +778,6 @@ def _lu_solve(
unit_diagonal=trans, unit_diagonal=trans,
trans=trans, trans=trans,
b_ndim=b_ndim, b_ndim=b_ndim,
check_finite=check_finite,
) )
# TODO: Use PermuteRows(inverse=True) on x # TODO: Use PermuteRows(inverse=True) on x
...@@ -867,7 +810,7 @@ def lu_solve( ...@@ -867,7 +810,7 @@ def lu_solve(
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input. of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True. Unused by PyTensor. PyTensor will return nan if the operation fails.
overwrite_b: bool overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible. Ignored by Pytensor. Pytensor will always compute inplace when possible.
""" """
...@@ -876,9 +819,7 @@ def lu_solve( ...@@ -876,9 +819,7 @@ def lu_solve(
signature = "(m,m),(m),(m)->(m)" signature = "(m,m),(m),(m)->(m)"
else: else:
signature = "(m,m),(m),(m,n)->(m,n)" signature = "(m,m),(m),(m,n)->(m,n)"
partialled_func = partial( partialled_func = partial(_lu_solve, trans=trans, b_ndim=b_ndim)
_lu_solve, trans=trans, b_ndim=b_ndim, check_finite=check_finite
)
return pt.vectorize(partialled_func, signature=signature)(*LU_and_pivots, b) return pt.vectorize(partialled_func, signature=signature)(*LU_and_pivots, b)
...@@ -888,7 +829,6 @@ class SolveTriangular(SolveBase): ...@@ -888,7 +829,6 @@ class SolveTriangular(SolveBase):
__props__ = ( __props__ = (
"unit_diagonal", "unit_diagonal",
"lower", "lower",
"check_finite",
"b_ndim", "b_ndim",
"overwrite_b", "overwrite_b",
) )
...@@ -905,10 +845,7 @@ class SolveTriangular(SolveBase): ...@@ -905,10 +845,7 @@ class SolveTriangular(SolveBase):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
A, b = inputs A, b = inputs
if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()): if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError("array must not contain infs or NaNs")
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError("expected square matrix") raise ValueError("expected square matrix")
if A.shape[0] != b.shape[0]: if A.shape[0] != b.shape[0]:
...@@ -941,12 +878,8 @@ class SolveTriangular(SolveBase): ...@@ -941,12 +878,8 @@ class SolveTriangular(SolveBase):
unitdiag=self.unit_diagonal, unitdiag=self.unit_diagonal,
) )
if info > 0: if info != 0:
raise LinAlgError( x[...] = np.nan
f"singular matrix: resolution failed at diagonal {info - 1}"
)
elif info < 0:
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
outputs[0][0] = x outputs[0][0] = x
...@@ -998,9 +931,7 @@ def solve_triangular( ...@@ -998,9 +931,7 @@ def solve_triangular(
unit_diagonal: bool, optional unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced. If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers. Unused by PyTensor. PyTensor will return nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int b_ndim : int
Whether the core case of b is a vector (1) or matrix (2). Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
...@@ -1018,7 +949,6 @@ def solve_triangular( ...@@ -1018,7 +949,6 @@ def solve_triangular(
SolveTriangular( SolveTriangular(
lower=lower, lower=lower,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim=b_ndim, b_ndim=b_ndim,
) )
)(a, b) )(a, b)
...@@ -1033,7 +963,6 @@ class Solve(SolveBase): ...@@ -1033,7 +963,6 @@ class Solve(SolveBase):
__props__ = ( __props__ = (
"assume_a", "assume_a",
"lower", "lower",
"check_finite",
"b_ndim", "b_ndim",
"overwrite_a", "overwrite_a",
"overwrite_b", "overwrite_b",
...@@ -1073,15 +1002,18 @@ class Solve(SolveBase): ...@@ -1073,15 +1002,18 @@ class Solve(SolveBase):
def perform(self, node, inputs, outputs):
    """Solve the linear system ``a @ x = b`` and store ``x`` in ``outputs[0][0]``.

    Parameters
    ----------
    node
        The apply node being evaluated (not used here).
    inputs
        Pair ``(a, b)`` of arrays: coefficient matrix and right-hand side.
    outputs
        Output storage; the solution is written to ``outputs[0][0]``.

    Notes
    -----
    This method does not raise when SciPy reports a ``LinAlgError``
    (e.g. an exactly singular ``a``). Instead, the output is filled with
    NaN. Finiteness of the inputs is never checked (``check_finite=False``).
    """
    a, b = inputs
    try:
        outputs[0][0] = scipy_linalg.solve(
            a=a,
            b=b,
            lower=self.lower,
            check_finite=False,
            assume_a=self.assume_a,
            overwrite_a=self.overwrite_a,
            overwrite_b=self.overwrite_b,
        )
    except np.linalg.LinAlgError:
        # The solution `x` has the shape of `b` (vector or matrix), not the
        # square shape of `a`. Promote with float32 so the NaN fill is always
        # representable, even if integer inputs reach this point.
        nan_dtype = np.result_type(a.dtype, b.dtype, np.float32)
        outputs[0][0] = np.full(b.shape, np.nan, dtype=nan_dtype)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs: if not allowed_inplace_inputs:
...@@ -1152,10 +1084,8 @@ def solve( ...@@ -1152,10 +1084,8 @@ def solve(
Unused by PyTensor. PyTensor will always perform the operation in-place if possible. Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible. Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional check_finite : bool
Whether to check that the input matrices contain only finite numbers. Unused by PyTensor. PyTensor returns nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional assume_a : str, optional
Valid entries are explained above. Valid entries are explained above.
transposed: bool, default False transposed: bool, default False
...@@ -1174,7 +1104,6 @@ def solve( ...@@ -1174,7 +1104,6 @@ def solve(
b, b,
lower=lower, lower=lower,
trans=transposed, trans=transposed,
check_finite=check_finite,
b_ndim=b_ndim, b_ndim=b_ndim,
) )
...@@ -1195,7 +1124,6 @@ def solve( ...@@ -1195,7 +1124,6 @@ def solve(
return Blockwise( return Blockwise(
Solve( Solve(
lower=lower, lower=lower,
check_finite=check_finite,
assume_a=assume_a, assume_a=assume_a,
b_ndim=b_ndim, b_ndim=b_ndim,
) )
...@@ -1779,7 +1707,6 @@ class QR(Op): ...@@ -1779,7 +1707,6 @@ class QR(Op):
"overwrite_a", "overwrite_a",
"mode", "mode",
"pivoting", "pivoting",
"check_finite",
) )
def __init__( def __init__(
...@@ -1787,12 +1714,10 @@ class QR(Op): ...@@ -1787,12 +1714,10 @@ class QR(Op):
mode: Literal["full", "r", "economic", "raw"] = "full", mode: Literal["full", "r", "economic", "raw"] = "full",
overwrite_a: bool = False, overwrite_a: bool = False,
pivoting: bool = False, pivoting: bool = False,
check_finite: bool = False,
): ):
self.mode = mode self.mode = mode
self.overwrite_a = overwrite_a self.overwrite_a = overwrite_a
self.pivoting = pivoting self.pivoting = pivoting
self.check_finite = check_finite
self.destroy_map = {} self.destroy_map = {}
......
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal
from pytensor.tensor.slinalg import Cholesky, Solve from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor.type import as_xtensor from pytensor.xtensor.type import as_xtensor
...@@ -10,8 +9,7 @@ def cholesky( ...@@ -10,8 +9,7 @@ def cholesky(
x, x,
lower: bool = True, lower: bool = True,
*, *,
check_finite: bool = False, check_finite: bool = True,
on_error: Literal["raise", "nan"] = "raise",
dims: Sequence[str], dims: Sequence[str],
): ):
"""Compute the Cholesky decomposition of an XTensorVariable. """Compute the Cholesky decomposition of an XTensorVariable.
...@@ -22,22 +20,15 @@ def cholesky( ...@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose. The input variable to decompose.
lower : bool, optional lower : bool, optional
Whether to return the lower triangular matrix. Default is True. Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional check_finite : bool
Whether to check that the input is finite. Default is False. Unused by PyTensor. PyTensor will return nan if the operation fails.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str] dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed. The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
""" """
if len(dims) != 2: if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}") raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
core_op = Cholesky( core_op = Cholesky(lower=lower)
lower=lower,
check_finite=check_finite,
on_error=on_error,
)
core_dims = ( core_dims = (
((dims[0], dims[1]),), ((dims[0], dims[1]),),
((dims[0], dims[1]),), ((dims[0], dims[1]),),
...@@ -52,7 +43,7 @@ def solve( ...@@ -52,7 +43,7 @@ def solve(
dims: Sequence[str], dims: Sequence[str],
assume_a="gen", assume_a="gen",
lower: bool = False, lower: bool = False,
check_finite: bool = False, check_finite: bool = True,
): ):
"""Solve a system of linear equations using XTensorVariables. """Solve a system of linear equations using XTensorVariables.
...@@ -75,8 +66,8 @@ def solve( ...@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"]. Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos". Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool, optional check_finite : bool
Whether to check that the input is finite. Default is False. Unused by PyTensor. PyTensor will return nan if the operation fails.
""" """
a, b = as_xtensor(a), as_xtensor(b) a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
...@@ -98,9 +89,7 @@ def solve( ...@@ -98,9 +89,7 @@ def solve(
else: else:
raise ValueError("Solve dims must have length 2 or 3") raise ValueError("Solve dims must have length 2 or 3")
core_op = Solve( core_op = Solve(b_ndim=b_ndim, assume_a=assume_a, lower=lower)
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
)
x_op = XBlockwise( x_op = XBlockwise(
core_op, core_op,
core_dims=(input_core_dims, output_core_dims), core_dims=(input_core_dims, output_core_dims),
......
import re
from typing import Literal from typing import Literal
import numpy as np import numpy as np
...@@ -36,70 +35,6 @@ floatX = config.floatX ...@@ -36,70 +35,6 @@ floatX = config.floatX
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
def test_lamch():
    """Compare the numba ``_xlamch`` wrapper with SciPy's LAPACK ``lamch``.

    Each query character is passed to both implementations and the returned
    values are required to agree.
    """
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.utils import _xlamch

    @numba.njit()
    def numba_lamch(kind):
        return _xlamch(kind)

    # A one-element array of the configured float type selects the LAPACK flavour.
    probe = np.array([0.0], dtype=floatX)
    scipy_lamch = get_lapack_funcs("lamch", (probe,))

    for kind in ("E", "S", "P", "B", "R", "M"):
        np.testing.assert_allclose(numba_lamch(kind), scipy_lamch(kind))
@pytest.mark.parametrize(
    "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)]
)
def test_xlange(ord_numba, ord_scipy):
    """Check the numba ``_xlange`` wrapper against ``scipy.linalg.norm``.

    ``ord_numba`` is the norm selector passed to ``_xlange``; ``ord_scipy`` is
    the matching ``ord`` argument for ``scipy.linalg.norm``.
    """
    # xlange is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy import linalg

    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    @numba.njit()
    def xlange(x, norm_kind):
        return _xlange(x, norm_kind)

    # Draw from the module-level seeded generator so a failure is reproducible.
    x = rng.normal(size=(5, 5)).astype(floatX)
    np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy))
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
def test_xgecon(ord_numba, ord_scipy):
    """Check the numba ``_xgecon`` wrapper against SciPy's LAPACK ``gecon``.

    The matrix norm needed by ``gecon`` is computed with ``lange`` on both
    sides, so only the two LAPACK bindings are compared with each other.
    ``ord_scipy`` is part of the parametrization but is not used in the body.
    """
    # gecon is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.linalg.solve.general import _xgecon
    from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange

    @numba.njit()
    def numba_gecon(x, norm):
        anorm = _xlange(x, norm)
        cond, info = _xgecon(x, anorm, norm)
        return cond, info

    # Draw from the module-level seeded generator so a failure is reproducible.
    x = rng.normal(size=(5, 5)).astype(floatX)

    rcond, info = numba_gecon(x, norm=ord_numba)

    # Test against direct call to the underlying LAPACK functions
    # Solution does **not** agree with 1 / np.linalg.cond(x) !
    lapack_lange, lapack_gecon = get_lapack_funcs(("lange", "gecon"), (x,))
    anorm = lapack_lange(ord_numba, x)
    expected_rcond, _ = lapack_gecon(x, anorm, norm=ord_numba)

    assert info == 0
    np.testing.assert_allclose(rcond, expected_rcond)
class TestSolves: class TestSolves:
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -323,7 +258,7 @@ class TestSolves: ...@@ -323,7 +258,7 @@ class TestSolves:
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize("value", [np.nan, np.inf]) @pytest.mark.parametrize("value", [np.nan, np.inf])
def test_solve_triangular_raises_on_nan_inf(self, value): def test_solve_triangular_does_not_raise_on_nan_inf(self, value):
A = pt.matrix("A") A = pt.matrix("A")
b = pt.matrix("b") b = pt.matrix("b")
...@@ -335,11 +270,8 @@ class TestSolves: ...@@ -335,11 +270,8 @@ class TestSolves:
A_tri = np.linalg.cholesky(A_sym).astype(floatX) A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX) b = np.full((5, 1), value).astype(floatX)
with pytest.raises( # Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
np.linalg.LinAlgError, assert not np.isfinite(f(A_tri, b)).any()
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -567,10 +499,13 @@ class TestDecompositions: ...@@ -567,10 +499,13 @@ class TestDecompositions:
x = pt.tensor(dtype=floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x) x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True) with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, check_finite=True, on_error="raise")
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): with pytest.raises(
np.linalg.LinAlgError, match=r"Matrix is not positive definite"
):
f(test_value) f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"]) @pytest.mark.parametrize("on_error", ["nan", "raise"])
...@@ -578,13 +513,17 @@ class TestDecompositions: ...@@ -578,13 +513,17 @@ class TestDecompositions:
test_value = rng.random(size=(3, 3)).astype(floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error) if on_error == "raise":
with pytest.warns(FutureWarning):
g = pt.linalg.cholesky(x, on_error=on_error)
else:
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise": if on_error == "raise":
with pytest.raises( with pytest.raises(
np.linalg.LinAlgError, np.linalg.LinAlgError,
match=r"Input to cholesky is not positive definite", match=r"Matrix is not positive definite",
): ):
f(test_value) f(test_value)
else: else:
......
...@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): ...@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize(
    "assume_a, counter",
    (
        ("gen", LUOpCounter),
        ("pos", CholeskyOpCounter),
    ),
)
def test_decomposition_reused_preserves_check_finite(assume_a, counter):
    """Check that the decomposition-reuse rewrite keeps the check_finite flag.

    Two solves share the matrix ``A``; only the first one asks for
    ``check_finite=True``. After the rewrite, non-finite values in ``A`` or
    ``b1`` must raise, while non-finite values in ``b2`` must not.
    """
    A = tensor("A", shape=(2, 2))
    b1 = tensor("b1", shape=(2,))
    b2 = tensor("b2", shape=(2,))

    checked = solve(A, b1, assume_a=assume_a, check_finite=True)
    unchecked = solve(A, b2, assume_a=assume_a, check_finite=False)

    mode = get_default_mode().including(reuse_decomposition_multiple_solves.__name__)
    fn_opt = function([A, b1, b2], [checked, unchecked], mode=mode)

    # One shared decomposition feeding two solves, no plain Solve left over.
    nodes = fn_opt.maker.fgraph.apply_nodes
    assert counter.count_vanilla_solve_nodes(nodes) == 0
    assert counter.count_decomp_nodes(nodes) == 1
    assert counter.count_solve_nodes(nodes) == 2

    # We should get an error if A or b1 is non finite
    A_ok = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
    b1_ok = np.array([1, 1], dtype=b1.type.dtype)
    b2_ok = np.array([1, 1], dtype=b2.type.dtype)

    assert fn_opt(A_ok, b1_ok, b2_ok)  # Fine
    # Should not raise (also fine on most LAPACK implementations?)
    assert fn_opt(A_ok, b1_ok, b2_ok * np.nan)

    err_msg = (
        "(array must not contain infs or NaNs"
        r"|Non-numeric values \(nan or inf\))"
    )
    expected_errors = (ValueError, np.linalg.LinAlgError)
    with pytest.raises(expected_errors, match=err_msg):
        assert fn_opt(A_ok, b1_ok * np.nan, b2_ok)
    with pytest.raises(expected_errors, match=err_msg):
        assert fn_opt(A_ok * np.nan, b1_ok, b2_ok)
...@@ -74,9 +74,6 @@ def test_cholesky(): ...@@ -74,9 +74,6 @@ def test_cholesky():
chol = Cholesky(lower=False)(x) chol = Cholesky(lower=False)(x)
ch_f = function([x], chol) ch_f = function([x], chol)
check_upper_triangular(pd, ch_f) check_upper_triangular(pd, ch_f)
chol = Cholesky(lower=False, on_error="nan")(x)
ch_f = function([x], chol)
check_upper_triangular(pd, ch_f)
def test_cholesky_performance(benchmark): def test_cholesky_performance(benchmark):
...@@ -102,12 +99,15 @@ def test_cholesky_empty(): ...@@ -102,12 +99,15 @@ def test_cholesky_empty():
def test_cholesky_indef(): def test_cholesky_indef():
x = matrix() x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], cholesky(x)) with pytest.warns(FutureWarning):
out = cholesky(x, lower=True, on_error="raise")
chol_f = function([x], out)
with pytest.raises(scipy.linalg.LinAlgError): with pytest.raises(scipy.linalg.LinAlgError):
chol_f(mat) chol_f(mat)
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], cholesky(x)) out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], out)
assert np.all(np.isnan(chol_f(mat))) assert np.all(np.isnan(chol_f(mat)))
...@@ -143,12 +143,16 @@ def test_cholesky_grad(): ...@@ -143,12 +143,16 @@ def test_cholesky_grad():
def test_cholesky_grad_indef(): def test_cholesky_grad_indef():
x = matrix() x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX) mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
cholesky = Cholesky(lower=True, on_error="raise")
chol_f = function([x], grad(cholesky(x).sum(), [x])) with pytest.warns(FutureWarning):
with pytest.raises(scipy.linalg.LinAlgError): out = cholesky(x, lower=True, on_error="raise")
chol_f(mat) chol_f = function([x], grad(out.sum(), [x]), mode="FAST_RUN")
cholesky = Cholesky(lower=True, on_error="nan")
chol_f = function([x], grad(cholesky(x).sum(), [x])) # original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert np.all(np.isnan(chol_f(mat)))
out = cholesky(x, lower=True, on_error="nan")
chol_f = function([x], grad(out.sum(), [x]))
assert np.all(np.isnan(chol_f(mat))) assert np.all(np.isnan(chol_f(mat)))
...@@ -237,7 +241,7 @@ class TestSolveBase: ...@@ -237,7 +241,7 @@ class TestSolveBase:
y = self.SolveTest(b_ndim=2)(A, b) y = self.SolveTest(b_ndim=2)(A, b)
assert ( assert (
y.__repr__() y.__repr__()
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" == "SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
) )
...@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self): def test_repr(self):
assert ( assert (
repr(CholeskySolve(lower=True, b_ndim=1)) repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)" == "CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
) )
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论