提交 19023545 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor numba lapack codegen

上级 2774599e
...@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs): ...@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=( message=(
"(\x1b\\[1m)*" # ansi escape code for bold text "(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function " "Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" ' '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
"as it uses dynamic globals" "as it uses dynamic globals"
), ),
category=NumbaWarning, category=NumbaWarning,
......
...@@ -390,3 +390,70 @@ class _LAPACK: ...@@ -390,3 +390,70 @@ class _LAPACK:
_ptr_int, # INFO _ptr_int, # INFO
) )
return functype(lapack_ptr) return functype(lapack_ptr)
@classmethod
def numba_xgttrf(cls, dtype):
"""
Compute the LU factorization of a tridiagonal matrix A using row interchanges.
Called by scipy.linalg.lu_factor
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # N
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xgttrs(cls, dtype):
"""
Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
Called by scipy.linalg.lu_solve
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xgtcon(cls, dtype):
"""
Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gtcon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # NORM
_ptr_int, # N
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return (
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)
if lower:
for j in range(1, _N):
for i in range(j):
A_copy[i, j] = 0.0
else:
for j in range(_N):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO)
return impl
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve")
dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
if C.flags.f_contiguous or C.flags.c_contiguous:
C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_potrs(
UPLO,
N,
NRHS,
C_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
_solve_check(_N, int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0]
return B_copy
return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
A_NORM = np.array(A_norm, dtype=dtype)
NORM = val_to_int_ptr(ord(norm))
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(4 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(1)
numba_gecon(
NORM,
N,
A_copy.view(w_type).ctypes,
LDA,
A_NORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrf)
def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)
def impl(
A: np.ndarray, overwrite_a: bool = False
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
return A_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrs)
def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
IPIV = _copy_to_fortran_order(IPIV)
INFO = val_to_int_ptr(0)
numba_getrs(
TRANS,
N,
NRHS,
LU.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
return impl
def _solve_gen(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="gen",
transposed=transposed,
)
@overload(_solve_gen)
def solve_gen_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
if overwrite_a and A.flags.c_contiguous:
# Work with the transposed system to avoid copying A
A = A.T
transposed = not transposed
order = "I" if transposed else "1"
norm = _xlange(A, order=order)
N = A.shape[1]
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
_solve_check(N, INFO)
X, INFO = _getrs(
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b
)
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1")
_solve_check(N, INFO, True, RCOND)
return X
return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "norm")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_lange = _LAPACK().numba_xlange(dtype)
def impl(A: np.ndarray, order: str | None = None):
_M, _N = np.int32(A.shape[-2:]) # type: ignore
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
NORM = (
val_to_int_ptr(ord(order))
if order is not None
else val_to_int_ptr(ord("1"))
)
WORK = np.empty(_M, dtype=dtype) # type: ignore
result = numba_lange(
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
)
return result
return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
def _posv(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_posv)
def posv_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_posv = _LAPACK().numba_xposv(dtype)
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
NRHS = 1 if B_is_1d else int(B.shape[-1])
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_posv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, int_ptr_to_val(INFO)
return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "pocon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_pocon = _LAPACK().numba_xpocon(dtype)
def impl(A: np.ndarray, anorm: float):
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
UPLO = val_to_int_ptr(ord("L"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(3 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_pocon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
transposed=transposed,
assume_a="pos",
)
@overload(_solve_psd)
def solve_psd_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
return x
return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_sysv)
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
_check_scipy_linalg_matrix(B, "sysv")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sysv = _LAPACK().numba_xsysv(dtype)
def impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
):
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N) # type: ignore
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_LDA) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
LDB = val_to_int_ptr(_N) # type: ignore
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
INFO = val_to_int_ptr(0)
# Workspace query
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
WS_SIZE = np.int32(WORK[0].real)
LWORK = val_to_int_ptr(WS_SIZE)
WORK = np.empty(WS_SIZE, dtype=dtype)
# Actual solve
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sycon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sycon = _LAPACK().numba_xsycon(dtype)
def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("U"))
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_sycon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ipiv.ctypes,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="sym",
transposed=transposed,
)
@overload(_solve_symmetric)
def solve_symmetric_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond)
return x
return impl
import numpy as np
from numba.core import types
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
import pytensor.
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
used by scipy.linalg.solve_triangular.
"""
return linalg.solve_triangular(
A,
B,
trans=trans,
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
)
@overload(_solve_triangular)
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise TypeError(
"This function is not expected to work with complex numbers yet"
)
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower = not lower
trans = 1 - trans
else:
A_f = np.asfortranarray(A)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_trtrs(
UPLO,
TRANS,
DIAG,
N,
NRHS,
A_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0]
return B_copy
return impl
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
@numba_basic.numba_njit(inline="always")
def _solve_check_input_shapes(A, B):
if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")
if A.shape[-2] != A.shape[-1]:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
from collections.abc import Callable
import numba
from numba.core import types
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
@numba_basic.numba_njit(inline="always")
def _trans_char_to_int(trans):
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of 0, 1, 2')
if trans == 0:
return ord("N")
elif trans == 1:
return ord("T")
else:
return ord("C")
def _check_scipy_linalg_matrix(a, func_name):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix = "scipy.linalg"
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = f"{prefix}.{func_name}() only supported for array types"
raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]:
msg = (
f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, types.Float | types.Complex):
msg = f"{prefix}.{func_name}() only supported on float and complex arrays."
raise numba.TypingError(msg, highlighting=False)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if info < 0:
# TODO: figure out how to do an fstring here
msg = "LAPACK reported an illegal value in input"
raise ValueError(msg)
elif 0 < info:
raise LinAlgError("Matrix is singular.")
if lamch:
E = _xlamch("E")
if rcond < E:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print( # noqa: T201
"Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
)
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
"""
Compute the machine precision for a given floating point type.
"""
from pytensor import config
ensure_lapack()
w_type = _get_underlying_float(config.floatX)
if w_type == "float32":
dtype = types.float32
elif w_type == "float64":
dtype = types.float64
else:
raise NotImplementedError("Unsupported dtype")
numba_lamch = _LAPACK().numba_xlamch(dtype)
def impl(kind: str = "E") -> float:
KIND = val_to_int_ptr(ord(kind))
return numba_lamch(KIND) # type: ignore
return impl
import warnings import warnings
from collections.abc import Callable
import numba
import numpy as np import numpy as np
from numba.core import types
from numba.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch._LAPACK import ( from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
_LAPACK, from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
_get_underlying_float, from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
int_ptr_to_val, from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
val_to_int_ptr, from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
) from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
...@@ -33,265 +25,6 @@ _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( ...@@ -33,265 +25,6 @@ _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
) )
@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if info < 0:
# TODO: figure out how to do an fstring here
msg = "LAPACK reported an illegal value in input"
raise ValueError(msg)
elif 0 < info:
raise LinAlgError("Matrix is singular.")
if lamch:
E = _xlamch("E")
if rcond < E:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print( # noqa: T201
"Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
)
def _check_scipy_linalg_matrix(a, func_name):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix = "scipy.linalg"
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = f"{prefix}.{func_name}() only supported for array types"
raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]:
msg = (
f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, types.Float | types.Complex):
msg = f"{prefix}.{func_name}() only supported on float and complex arrays."
raise numba.TypingError(msg, highlighting=False)
def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
import pytensor.
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
used by scipy.linalg.solve_triangular.
"""
return linalg.solve_triangular(
A,
B,
trans=trans,
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
)
@numba_basic.numba_njit(inline="always")
def _trans_char_to_int(trans):
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of 0, 1, 2')
if trans == 0:
return ord("N")
elif trans == 1:
return ord("T")
else:
return ord("C")
@numba_basic.numba_njit(inline="always")
def _solve_check_input_shapes(A, B):
if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")
if A.shape[-2] != A.shape[-1]:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
@overload(_solve_triangular)
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise TypeError("This function is not expected to work with complex numbers")
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1
if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower = not lower
trans = 1 - trans
else:
A_f = np.asfortranarray(A)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_trtrs(
UPLO,
TRANS,
DIAG,
N,
NRHS,
A_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0]
return B_copy
return impl
@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
overwrite_b = op.overwrite_b
b_ndim = op.b_ndim
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
raise NotImplementedError(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
)
@numba_basic.numba_njit(inline="always")
def solve_triangular(a, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res = _solve_triangular(
a,
b,
trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
b_ndim=b_ndim,
)
return res
return solve_triangular
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return (
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)
if lower:
for j in range(1, _N):
for i in range(j):
A_copy[i, j] = 0.0
else:
for j in range(_N):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0
return A_copy, int_ptr_to_val(INFO)
return impl
@numba_funcify.register(Cholesky) @numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs): def numba_funcify_Cholesky(op, node, **kwargs):
""" """
...@@ -309,8 +42,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -309,8 +42,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit(inline="always") @numba_njit
def nb_cholesky(a): def cholesky(a):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
...@@ -333,7 +66,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -333,7 +66,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return res return res
return nb_cholesky return cholesky
@numba_funcify.register(BlockDiagonal) @numba_funcify.register(BlockDiagonal)
...@@ -341,7 +74,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -341,7 +74,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype dtype = node.outputs[0].dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case. # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit(inline="never") @numba_njit
def block_diag(*arrs): def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int") shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)] out_shape = [int(s) for s in np.sum(shapes, axis=0)]
...@@ -359,731 +92,6 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -359,731 +92,6 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return block_diag return block_diag
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
"""
Compute the machine precision for a given floating point type.
"""
from pytensor import config
ensure_lapack()
w_type = _get_underlying_float(config.floatX)
if w_type == "float32":
dtype = types.float32
elif w_type == "float64":
dtype = types.float64
else:
raise NotImplementedError("Unsupported dtype")
numba_lamch = _LAPACK().numba_xlamch(dtype)
def impl(kind: str = "E") -> float:
KIND = val_to_int_ptr(ord(kind))
return numba_lamch(KIND) # type: ignore
return impl
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "norm")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_lange = _LAPACK().numba_xlange(dtype)
def impl(A: np.ndarray, order: str | None = None):
_M, _N = np.int32(A.shape[-2:]) # type: ignore
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
NORM = (
val_to_int_ptr(ord(order))
if order is not None
else val_to_int_ptr(ord("1"))
)
WORK = np.empty(_M, dtype=dtype) # type: ignore
result = numba_lange(
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
)
return result
return impl
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
A_NORM = np.array(A_norm, dtype=dtype)
NORM = val_to_int_ptr(ord(norm))
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(4 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(1)
numba_gecon(
NORM,
N,
A_copy.view(w_type).ctypes,
LDA,
A_NORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrf)
def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)
def impl(
A: np.ndarray, overwrite_a: bool = False
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
return A_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrs)
def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
IPIV = _copy_to_fortran_order(IPIV)
INFO = val_to_int_ptr(0)
numba_getrs(
TRANS,
N,
NRHS,
LU.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
return impl
def _solve_gen(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="gen",
transposed=transposed,
)
@overload(_solve_gen)
def solve_gen_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
if overwrite_a and A.flags.c_contiguous:
# Work with the transposed system to avoid copying A
A = A.T
transposed = not transposed
order = "I" if transposed else "1"
norm = _xlange(A, order=order)
N = A.shape[1]
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
_solve_check(N, INFO)
X, INFO = _getrs(
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b
)
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1")
_solve_check(N, INFO, True, RCOND)
return X
return impl
def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_sysv)
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
_check_scipy_linalg_matrix(B, "sysv")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sysv = _LAPACK().numba_xsysv(dtype)
def impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
):
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N) # type: ignore
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_LDA) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
LDB = val_to_int_ptr(_N) # type: ignore
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
INFO = val_to_int_ptr(0)
# Workspace query
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
WS_SIZE = np.int32(WORK[0].real)
LWORK = val_to_int_ptr(WS_SIZE)
WORK = np.empty(WS_SIZE, dtype=dtype)
# Actual solve
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sycon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sycon = _LAPACK().numba_xsycon(dtype)
def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("U"))
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_sycon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ipiv.ctypes,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="sym",
transposed=transposed,
)
@overload(_solve_symmetric)
def solve_symmetric_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond)
return x
return impl
def _posv(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_posv)
def posv_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_posv = _LAPACK().numba_xposv(dtype)
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
NRHS = 1 if B_is_1d else int(B.shape[-1])
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_posv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, int_ptr_to_val(INFO)
return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "pocon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_pocon = _LAPACK().numba_xpocon(dtype)
def impl(A: np.ndarray, anorm: float):
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
UPLO = val_to_int_ptr(ord("L"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(3 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_pocon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
transposed=transposed,
assume_a="pos",
)
@overload(_solve_psd)
def solve_psd_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
return x
return impl
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a assume_a = op.assume_a
...@@ -1109,12 +117,12 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1109,12 +117,12 @@ def numba_funcify_Solve(op, node, **kwargs):
else: else:
warnings.warn( warnings.warn(
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n" f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.", f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', or 'triangular' to improve performance.",
UserWarning, UserWarning,
) )
solve_fn = _solve_gen solve_fn = _solve_gen
@numba_basic.numba_njit(inline="always") @numba_njit
def solve(a, b): def solve(a, b):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -1132,74 +140,45 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1132,74 +140,45 @@ def numba_funcify_Solve(op, node, **kwargs):
return solve return solve
def _cho_solve( @numba_funcify.register(SolveTriangular)
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool def numba_funcify_SolveTriangular(op, node, **kwargs):
): lower = op.lower
""" unit_diagonal = op.unit_diagonal
Solve a positive-definite linear system using the Cholesky decomposition. check_finite = op.check_finite
""" overwrite_b = op.overwrite_b
return linalg.cho_solve( b_ndim = op.b_ndim
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve")
dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
if C.flags.f_contiguous or C.flags.c_contiguous:
C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) dtype = node.inputs[0].dtype
N = val_to_int_ptr(_N) if dtype in complex_dtypes:
NRHS = val_to_int_ptr(NRHS) raise NotImplementedError(
LDA = val_to_int_ptr(_N) _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
LDB = val_to_int_ptr(_N) )
INFO = val_to_int_ptr(0)
numba_potrs( @numba_njit
UPLO, def solve_triangular(a, b):
N, if check_finite:
NRHS, if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
C_f.view(w_type).ctypes, raise np.linalg.LinAlgError(
LDA, "Non-numeric values (nan or inf) in input A to solve_triangular"
B_copy.view(w_type).ctypes, )
LDB, if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
INFO, raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
) )
_solve_check(_N, int_ptr_to_val(INFO)) res = _solve_triangular(
a,
b,
trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
b_ndim=b_ndim,
)
if B_is_1d: return res
return B_copy[..., 0]
return B_copy
return impl return solve_triangular
@numba_funcify.register(CholeskySolve) @numba_funcify.register(CholeskySolve)
...@@ -1212,7 +191,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -1212,7 +191,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit(inline="always") @numba_njit
def cho_solve(c, b): def cho_solve(c, b):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
......
...@@ -566,7 +566,8 @@ class Solve(SolveBase): ...@@ -566,7 +566,8 @@ class Solve(SolveBase):
if 1 in allowed_inplace_inputs: if 1 in allowed_inplace_inputs:
# Give preference to overwrite_b # Give preference to overwrite_b
new_props["overwrite_b"] = True new_props["overwrite_b"] = True
else: # allowed inputs == [0] # We can't overwrite_a if we're assuming tridiagonal
elif not self.assume_a == "tridiagonal": # allowed inputs == [0]
new_props["overwrite_a"] = True new_props["overwrite_a"] = True
return type(self)(**new_props) return type(self)(**new_props)
......
...@@ -12,6 +12,8 @@ from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangu ...@@ -12,6 +12,8 @@ from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangu
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
pytestmark = pytest.mark.filterwarnings("error")
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
floatX = config.floatX floatX = config.floatX
...@@ -22,7 +24,7 @@ rng = np.random.default_rng(42849) ...@@ -22,7 +24,7 @@ rng = np.random.default_rng(42849)
def test_lamch(): def test_lamch():
from scipy.linalg import get_lapack_funcs from scipy.linalg import get_lapack_funcs
from pytensor.link.numba.dispatch.slinalg import _xlamch from pytensor.link.numba.dispatch.linalg.utils import _xlamch
@numba.njit() @numba.njit()
def xlamch(kind): def xlamch(kind):
...@@ -45,7 +47,7 @@ def test_xlange(ord_numba, ord_scipy): ...@@ -45,7 +47,7 @@ def test_xlange(ord_numba, ord_scipy):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it # xlange is called internally only, we don't dispatch pt.linalg.norm to it
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.slinalg import _xlange from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
@numba.njit() @numba.njit()
def xlange(x, ord): def xlange(x, ord):
...@@ -60,7 +62,8 @@ def test_xgecon(ord_numba, ord_scipy): ...@@ -60,7 +62,8 @@ def test_xgecon(ord_numba, ord_scipy):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it # gecon is called internally only, we don't dispatch pt.linalg.norm to it
from scipy.linalg import get_lapack_funcs from scipy.linalg import get_lapack_funcs
from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange from pytensor.link.numba.dispatch.linalg.solve.general import _xgecon
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
@numba.njit() @numba.njit()
def gecon(x, norm): def gecon(x, norm):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论