提交 bbe663d9 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Implement numba dispatch for all `linalg.solve` modes

上级 8e5e8a40
import ctypes
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.np.linalg import ensure_lapack, get_blas_kind
_PTR = ctypes.POINTER
_dbl = ctypes.c_double
_float = ctypes.c_float
_char = ctypes.c_char
_int = ctypes.c_int
_ptr_float = _PTR(_float)
_ptr_dbl = _PTR(_dbl)
_ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)
def _get_lapack_ptr_and_ptr_type(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
return lapack_ptr, float_pointer
def _get_underlying_float(dtype):
s_dtype = str(dtype)
out_type = s_dtype
if s_dtype == "complex64":
out_type = "float32"
elif s_dtype == "complex128":
out_type = "float64"
return np.dtype(out_type)
def _get_float_pointer_for_dtype(blas_dtype):
if blas_dtype in ["s", "c"]:
return _ptr_float
elif blas_dtype in ["d", "z"]:
return _ptr_dbl
def _get_output_ctype(dtype):
s_dtype = str(dtype)
if s_dtype in ["float32", "complex64"]:
return _float
elif s_dtype in ["float64", "complex128"]:
return _dbl
@intrinsic
def sptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float32(types.CPointer(types.float32))
return sig, impl
@intrinsic
def dptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float64(types.CPointer(types.float64))
return sig, impl
@intrinsic
def int_ptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.int32(types.CPointer(types.int32))
return sig, impl
@intrinsic
def val_to_int_ptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.int32)(types.int32)
return sig, impl
@intrinsic
def val_to_sptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float32)(types.float32)
return sig, impl
@intrinsic
def val_to_zptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.complex128)(types.complex128)
return sig, impl
@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float64)(types.float64)
return sig, impl
class _LAPACK:
"""
Functions to return type signatures for wrapped LAPACK functions.
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""
def __init__(self):
ensure_lapack()
@classmethod
def numba_xtrtrs(cls, dtype):
"""
Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.
Called by scipy.linalg.solve_triangular
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # TRANS
_ptr_int, # DIAG
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xpotrf(cls, dtype):
"""
Compute the Cholesky factorization of a real symmetric positive definite matrix.
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xpotrs(cls, dtype):
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by numba_potrf.
Called by scipy.linalg.cho_solve
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xlange(cls, dtype):
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange")
output_ctype = _get_output_ctype(dtype)
functype = ctypes.CFUNCTYPE(
output_ctype, # Output
_ptr_int, # NORM
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
float_pointer, # WORK
)
return functype(lapack_ptr)
@classmethod
def numba_xlamch(cls, dtype):
"""
Determine machine precision for floating point arithmetic.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch")
output_dtype = _get_output_ctype(dtype)
functype = ctypes.CFUNCTYPE(
output_dtype, # Output
_ptr_int, # CMACH
)
return functype(lapack_ptr)
@classmethod
def numba_xgecon(cls, dtype):
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # NORM
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xgetrf(cls, dtype):
"""
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
Called by scipy.linalg.lu_factor
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xgetrs(cls, dtype):
"""
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
factorization computed by GETRF.
Called by scipy.linalg.lu_solve
"""
...
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xsysv(cls, dtype):
"""
Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
factorizing A into LDL^T or UDU^T form, depending on the value of UPLO
Called by scipy.linalg.solve when assume_a == "sym"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
float_pointer, # B
_ptr_int, # LDB
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xsycon(cls, dtype):
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xpocon(cls, dtype):
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xposv(cls, dtype):
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by potrf.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
......@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs):
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
"""Create a Numba compatible function from a Pytensor `Op`."""
warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
......
import ctypes
from collections.abc import Callable
import numba
import numpy as np
from numba.core import cgutils, types
from numba.extending import get_cython_function_address, intrinsic, overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind
from numba.core import types
from numba.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
CholeskySolve,
Solve,
SolveTriangular,
)
_PTR = ctypes.POINTER
_dbl = ctypes.c_double
_float = ctypes.c_float
_char = ctypes.c_char
_int = ctypes.c_int
_ptr_float = _PTR(_float)
_ptr_dbl = _PTR(_dbl)
_ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)
@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
for v in np.nditer(a):
if not np.isfinite(v.item()):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input to " + func_name
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if info < 0:
# TODO: figure out how to do an fstring here
msg = "LAPACK reported an illegal value in input"
raise ValueError(msg)
elif 0 < info:
raise LinAlgError("Matrix is singular.")
if lamch:
E = _xlamch("E")
if rcond < E:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print( # noqa: T201
"Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
)
@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float64)(types.float64)
return sig, impl
@intrinsic
def val_to_zptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.complex128)(types.complex128)
return sig, impl
@intrinsic
def val_to_sptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float32)(types.float32)
return sig, impl
@intrinsic
def val_to_int_ptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.int32)(types.int32)
return sig, impl
@intrinsic
def int_ptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.int32(types.CPointer(types.int32))
return sig, impl
@intrinsic
def dptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float64(types.CPointer(types.float64))
return sig, impl
@intrinsic
def sptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float32(types.CPointer(types.float32))
return sig, impl
def _get_float_pointer_for_dtype(blas_dtype):
if blas_dtype in ["s", "c"]:
return _ptr_float
elif blas_dtype in ["d", "z"]:
return _ptr_dbl
def _get_underlying_float(dtype):
s_dtype = str(dtype)
out_type = s_dtype
if s_dtype == "complex64":
out_type = "float32"
elif s_dtype == "complex128":
out_type = "float64"
return np.dtype(out_type)
def _get_lapack_ptr_and_ptr_type(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
return lapack_ptr, float_pointer
def _check_scipy_linalg_matrix(a, func_name):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
......@@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name):
raise numba.TypingError(msg, highlighting=False)
class _LAPACK:
def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False
):
"""
Functions to return type signatures for wrapped LAPACK functions.
Thin wrapper around scipy.linalg.solve_triangular.
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""
def __init__(self):
ensure_lapack()
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
import pytensor.
@classmethod
def numba_xtrtrs(cls, dtype):
"""
Called by scipy.linalg.solve_triangular
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
used by scipy.linalg.solve_triangular.
"""
return linalg.solve_triangular(
A,
B,
trans=trans,
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
)
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # TRANS
_ptr_int, # DIAG
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr)
@classmethod
def numba_xpotrf(cls, dtype):
"""
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)
@numba_basic.numba_njit(inline="always")
def _trans_char_to_int(trans):
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of 0, 1, 2')
if trans == 0:
return ord("N")
elif trans == 1:
return ord("T")
else:
return ord("C")
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal
)
@numba_basic.numba_njit(inline="always")
def _solve_check_input_shapes(A, B):
if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")
if A.shape[-2] != A.shape[-1]:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
@overload(_solve_triangular)
def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular")
......@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
def impl(A, B, trans=0, lower=False, unit_diagonal=False):
B_is_1d = B.ndim == 1
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
_solve_check_input_shapes(A, B)
if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
if trans == 0:
transval = ord("N")
elif trans == 1:
transval = ord("T")
else:
transval = ord("C")
B_copy = B
B_NDIM = 1 if B_is_1d else int(B.shape[1])
if B_is_1d:
B_copy = np.expand_dims(B, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(transval)
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(B_NDIM)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
......@@ -266,19 +158,24 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
INFO,
)
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return B_copy[..., 0]
return B_copy
return impl
@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = op.trans
trans = bool(op.trans)
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
overwrite_b = op.overwrite_b
b_ndim = op.b_ndim
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
......@@ -298,11 +195,16 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
if info != 0:
raise np.linalg.LinAlgError(
"Singular matrix in input A to solve_triangular"
)
res = _solve_triangular(
a,
b,
trans=trans,
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
b_ndim=b_ndim,
)
return res
return solve_triangular
......@@ -429,3 +331,853 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return out
return block_diag
def _xlamch(kind: str = "E"):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
"""
Compute the machine precision for a given floating point type.
"""
from pytensor import config
ensure_lapack()
w_type = _get_underlying_float(config.floatX)
if w_type == "float32":
dtype = types.float32
elif w_type == "float64":
dtype = types.float64
else:
raise NotImplementedError("Unsupported dtype")
numba_lamch = _LAPACK().numba_xlamch(dtype)
def impl(kind: str = "E") -> float:
KIND = val_to_int_ptr(ord(kind))
return numba_lamch(KIND) # type: ignore
return impl
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "norm")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_lange = _LAPACK().numba_xlange(dtype)
def impl(A: np.ndarray, order: str | None = None):
_M, _N = np.int32(A.shape[-2:]) # type: ignore
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
NORM = (
val_to_int_ptr(ord(order))
if order is not None
else val_to_int_ptr(ord("1"))
)
WORK = np.empty(_M, dtype=dtype) # type: ignore
result = numba_lange(
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
)
return result
return impl
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return # type: ignore
@overload(_xgecon)
def xgecon_impl(
A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
A_NORM = np.array(A_norm, dtype=dtype)
NORM = val_to_int_ptr(ord(norm))
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(4 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(1)
numba_gecon(
NORM,
N,
A_copy.view(w_type).ctypes,
LDA,
A_NORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrf)
def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)
def impl(
A: np.ndarray, overwrite_a: bool = False
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
LDA = val_to_int_ptr(_M) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
return A_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return # type: ignore
@overload(_getrs)
def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B)
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
IPIV = _copy_to_fortran_order(IPIV)
INFO = val_to_int_ptr(0)
numba_getrs(
TRANS,
N,
NRHS,
LU.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return impl
def _solve_gen(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="gen",
transposed=transposed,
)
@overload(_solve_gen)
def solve_gen_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B)
order = "I" if transposed else "1"
norm = _xlange(A, order=order)
N = A.shape[1]
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
_solve_check(N, INFO)
X, INFO = _getrs(
LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b
)
_solve_check(N, INFO)
RCOND, INFO = _xgecon(LU, norm, "1")
_solve_check(N, INFO, True, RCOND)
return X
return impl
def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_sysv)
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
_check_scipy_linalg_matrix(B, "sysv")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sysv = _LAPACK().numba_xsysv(dtype)
def impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
):
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N) # type: ignore
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_LDA) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
LDB = val_to_int_ptr(_N) # type: ignore
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
INFO = val_to_int_ptr(0)
# Workspace query
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
WS_SIZE = np.int32(WORK[0].real)
LWORK = val_to_int_ptr(WS_SIZE)
WORK = np.empty(WS_SIZE, dtype=dtype)
# Actual solve
numba_sysv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
LDB,
WORK.view(w_type).ctypes,
LWORK,
INFO,
)
if B_is_1d:
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
return B_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return # type: ignore
@overload(_sycon)
def sycon_impl(
A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sycon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sycon = _LAPACK().numba_xsycon(dtype)
def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("L"))
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_sycon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ipiv.ctypes,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
assume_a="sym",
transposed=transposed,
)
@overload(_solve_symmetric)
def solve_symmetric_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond)
return x
return impl
def _posv(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_posv)
def posv_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_posv = _LAPACK().numba_xposv(dtype)
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
NRHS = 1 if B_is_1d else int(B.shape[-1])
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_posv(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return # type: ignore
@overload(_pocon)
def pocon_impl(
A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "pocon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_pocon = _LAPACK().numba_xpocon(dtype)
def impl(A: np.ndarray, anorm: float):
_N = np.int32(A.shape[-1])
A_copy = _copy_to_fortran_order(A)
UPLO = val_to_int_ptr(ord("L"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(3 * _N, dtype=dtype)
IWORK = np.empty(_N, dtype=np.int32)
INFO = val_to_int_ptr(0)
numba_pocon(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
IWORK.ctypes,
INFO,
)
return RCOND, int_ptr_to_val(INFO)
return impl
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=check_finite,
transposed=transposed,
assume_a="pos",
)
@overload(_solve_psd)
def solve_psd_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(x, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
return x
return impl
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
lower = op.lower
check_finite = op.check_finite
overwrite_a = op.overwrite_a
overwrite_b = op.overwrite_b
transposed = False # TODO: Solve doesnt currently allow the transposed argument
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve in Numba mode"
)
if assume_a == "gen":
solve_fn = _solve_gen
elif assume_a == "sym":
solve_fn = _solve_symmetric
elif assume_a == "her":
raise NotImplementedError(
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
"please open an issue on github."
)
elif assume_a == "pos":
solve_fn = _solve_psd
else:
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
@numba_basic.numba_njit(inline="always")
def solve(a, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve"
)
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res
return solve
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
A, lower = A_and_lower
return linalg.cho_solve((A, lower), B)
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve")
dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
C_copy = _copy_to_fortran_order(C)
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
numba_potrs(
UPLO,
N,
NRHS,
C_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return impl
@numba_funcify.register(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower
overwrite_b = op.overwrite_b
check_finite = op.check_finite
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cho_solve in Numba mode"
)
@numba_basic.numba_njit(inline="always")
def cho_solve(c, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
res, info = _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
)
if info < 0:
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
elif info > 0:
raise np.linalg.LinAlgError(
"Matrix is not positive definite in input to cho_solve"
)
return res
return cho_solve
import logging
import typing
import warnings
from collections.abc import Sequence
from functools import reduce
from typing import Literal, cast
import numpy as np
import scipy.linalg
import scipy.linalg as scipy_linalg
import pytensor
import pytensor.tensor as pt
......@@ -58,7 +58,7 @@ class Cholesky(Op):
f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
# Call scipy to find output dtype
dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
def perform(self, node, inputs, outputs):
......@@ -68,21 +68,21 @@ class Cholesky(Op):
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy.linalg.cholesky(
out[0] = scipy_linalg.cholesky(
x.T,
lower=not self.lower,
check_finite=self.check_finite,
overwrite_a=True,
).T
else:
out[0] = scipy.linalg.cholesky(
out[0] = scipy_linalg.cholesky(
x,
lower=self.lower,
check_finite=self.check_finite,
overwrite_a=self.overwrite_a,
)
except scipy.linalg.LinAlgError:
except scipy_linalg.LinAlgError:
if self.on_error == "raise":
raise
else:
......@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase):
def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy.linalg.cho_solve(
rval = scipy_linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
......@@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
"""
......@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase):
def perform(self, node, inputs, outputs):
A, b = inputs
outputs[0][0] = scipy.linalg.solve_triangular(
outputs[0][0] = scipy_linalg.solve_triangular(
A,
b,
lower=self.lower,
......@@ -502,7 +502,7 @@ class Solve(SolveBase):
def perform(self, node, inputs, outputs):
a, b = inputs
outputs[0][0] = scipy.linalg.solve(
outputs[0][0] = scipy_linalg.solve(
a=a,
b=b,
lower=self.lower,
......@@ -619,9 +619,9 @@ class Eigvalsh(Op):
def perform(self, node, inputs, outputs):
(w,) = outputs
if len(inputs) == 2:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
else:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
def grad(self, inputs, g_outputs):
a, b = inputs
......@@ -675,7 +675,7 @@ class EigvalshGrad(Op):
def perform(self, node, inputs, outputs):
(a, b, gw) = inputs
w, v = scipy.linalg.eigh(a, b, lower=self.lower)
w, v = scipy_linalg.eigh(a, b, lower=self.lower)
gA = v.dot(np.diag(gw).dot(v.T))
gB = -v.dot(np.diag(gw * w).dot(v.T))
......@@ -718,7 +718,7 @@ class Expm(Op):
def perform(self, node, inputs, outputs):
(A,) = inputs
(expm,) = outputs
expm[0] = scipy.linalg.expm(A)
expm[0] = scipy_linalg.expm(A)
def grad(self, inputs, outputs):
(A,) = inputs
......@@ -758,8 +758,8 @@ class ExpmGrad(Op):
# this expression.
(A, gA) = inputs
(out,) = outputs
w, V = scipy.linalg.eig(A, right=True)
U = scipy.linalg.inv(V).T
w, V = scipy_linalg.eig(A, right=True)
U = scipy_linalg.inv(V).T
exp_w = np.exp(w)
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
......@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op):
X = output_storage[0]
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op):
X = output_storage[0]
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
)
......@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op):
Q = 0.5 * (Q + Q.T)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -1064,7 +1064,7 @@ def solve_discrete_are(
)
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
......@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal):
def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype)
def block_diag(*matrices: TensorVariable):
......@@ -1175,4 +1175,5 @@ __all__ = [
"solve_discrete_are",
"solve_triangular",
"block_diag",
"cho_solve",
]
......@@ -7,58 +7,13 @@ import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg, slinalg
from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"x, exc",
[
......
import re
from functools import partial
from typing import Literal
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy import linalg as scipy_linalg
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
numba = pytest.importorskip("numba")
ATOL = 0 if config.floatX.endswith("64") else 1e-6
RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
floatX = pytensor.config.floatX
rng = np.random.default_rng(42849)
......@@ -27,8 +31,8 @@ def transpose_func(x, trans):
@pytest.mark.parametrize(
"b_func, b_size",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
......@@ -36,50 +40,88 @@ def transpose_func(x, trans):
@pytest.mark.parametrize(
"unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
)
@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"])
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
if complex:
def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex):
if is_complex:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why?
pytest.skip("Complex inputs currently not supported to solve_triangular")
complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128"
dtype = complex_dtype if complex else config.floatX
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
A = pt.matrix("A", dtype=dtype)
b = b_func("b", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=dtype)
def A_func(x):
x = x @ x.conj().T
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
X = pt.linalg.solve_triangular(
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
if unit_diag:
x_tri[np.diag_indices_from(x_tri)] = 1.0
return x_tri.astype(dtype)
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)
X = solve_op(A, b)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5))
b = np.random.normal(size=b_size)
b_val = np.random.normal(size=b_shape)
if complex:
if is_complex:
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b = b + np.random.normal(size=b_size) * 1j
A_sym = A_val @ A_val.conj().T
b_val = b_val + np.random.normal(size=b_shape) * 1j
A_tri = np.linalg.cholesky(A_sym).astype(dtype)
if unit_diag:
adj_mat = np.ones((5, 5))
adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri)
A_tri = A_tri * adj_mat
X_np = f(A_func(A_val.copy()), b_val.copy())
A_tri = A_tri.astype(dtype)
b = b.astype(dtype)
test_input = transpose_func(A_func(A_val.copy()), trans)
if not lower:
A_tri = A_tri.T
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
X_np = f(A_tri, b)
np.testing.assert_allclose(
transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])
@pytest.mark.parametrize(
"lower, unit_diag, trans",
[(True, True, True), (False, False, False)],
ids=["lower_unit_trans", "defaults"],
)
def test_solve_triangular_grad(lower, unit_diag, trans):
A_val = np.random.normal(size=(5, 5)).astype(floatX)
b_val = np.random.normal(size=(5, 5)).astype(floatX)
# utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When
# a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be
# wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices
# to the space of triangular matrices, and test the gradient of that entire graph.
def A_func_pt(x):
x = x @ x.conj().T
x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX)
if unit_diag:
n = A_val.shape[0]
x_tri = x_tri[np.diag_indices(n)].set(1.0)
return transpose_func(x_tri.astype(floatX), trans)
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)
utt.verify_grad(
lambda A, b: solve_op(A_func_pt(A), b),
[A_val.copy(), b_val.copy()],
mode="NUMBA",
)
......@@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
X = pt.linalg.solve_triangular(A, b, check_finite=True)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5))
A_val = np.random.normal(size=(5, 5)).astype(floatX)
A_sym = A_val @ A_val.conj().T
A_tri = np.linalg.cholesky(A_sym).astype(config.floatX)
b = np.full((5, 1), value)
A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value).astype(floatX)
with pytest.raises(
np.linalg.LinAlgError,
......@@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans):
fg = FunctionGraph(outputs=[chol])
x = np.array([0.1, 0.2, 0.3])
val = np.eye(3) + x[None, :] * x[:, None]
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
compare_numba_and_py(fg, [val])
def test_numba_Cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(config.floatX)
test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
g = pt.linalg.cholesky(x)
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
......@@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input():
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(config.floatX)
test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
......@@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error):
assert np.all(np.isnan(f(test_value)))
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky_grad(lower):
rng = np.random.default_rng(utt.fetch_seed())
L = rng.normal(size=(5, 5)).astype(floatX)
X = L @ L.T
chol_op = partial(pt.linalg.cholesky, lower=lower)
utt.verify_grad(chol_op, [X], mode="NUMBA")
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
......@@ -162,9 +214,242 @@ def test_block_diag():
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5))
B_val = np.random.normal(size=(3, 3))
C_val = np.random.normal(size=(2, 2))
D_val = np.random.normal(size=(4, 4))
A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
def test_lamch():
from scipy.linalg import get_lapack_funcs
from pytensor.link.numba.dispatch.slinalg import _xlamch
@numba.njit()
def xlamch(kind):
return _xlamch(kind)
lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))
np.testing.assert_allclose(xlamch("E"), lamch("E"))
np.testing.assert_allclose(xlamch("S"), lamch("S"))
np.testing.assert_allclose(xlamch("P"), lamch("P"))
np.testing.assert_allclose(xlamch("B"), lamch("B"))
np.testing.assert_allclose(xlamch("R"), lamch("R"))
np.testing.assert_allclose(xlamch("M"), lamch("M"))
@pytest.mark.parametrize(
"ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)]
)
def test_xlange(ord_numba, ord_scipy):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from scipy import linalg
from pytensor.link.numba.dispatch.slinalg import _xlange
@numba.njit()
def xlange(x, ord):
return _xlange(x, ord)
x = np.random.normal(size=(5, 5)).astype(floatX)
np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy))
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
def test_xgecon(ord_numba, ord_scipy):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from scipy.linalg import get_lapack_funcs
from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange
@numba.njit()
def gecon(x, norm):
anorm = _xlange(x, norm)
cond, info = _xgecon(x, anorm, norm)
return cond, info
x = np.random.normal(size=(5, 5)).astype(floatX)
rcond, info = gecon(x, norm=ord_numba)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,))
norm = lange(ord_numba, x)
rcond2, _ = gecon(x, norm, norm=ord_numba)
assert info == 0
np.testing.assert_allclose(rcond, rcond2)
@pytest.mark.parametrize("overwrite_a", [True, False])
def test_getrf(overwrite_a):
from scipy.linalg import lu_factor
from pytensor.link.numba.dispatch.slinalg import _getrf
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
@numba.njit()
def getrf(x, overwrite_a):
return _getrf(x, overwrite_a=overwrite_a)
x = np.random.normal(size=(5, 5)).astype(floatX)
x = np.asfortranarray(
x
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
lu, ipiv = lu_factor(x, overwrite_a=False)
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
assert info == 0
assert_allclose(LU, lu)
if overwrite_a:
assert_allclose(x, LU)
# TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
# this, though.
assert_allclose(IPIV - 1, ipiv)
@pytest.mark.parametrize("trans", [0, 1])
@pytest.mark.parametrize("overwrite_a", [True, False])
@pytest.mark.parametrize("overwrite_b", [True, False])
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
from scipy.linalg import lu_factor
from scipy.linalg import lu_solve as sp_lu_solve
from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
@numba.njit()
def lu_solve(a, b, trans, overwrite_a, overwrite_b):
lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a)
x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b)
return x, lu, info
a = np.random.normal(size=(5, 5)).astype(floatX)
b = np.random.normal(size=b_shape).astype(floatX)
# inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
a = np.asfortranarray(a)
b = np.asfortranarray(b)
lu_and_piv = lu_factor(a, overwrite_a=False)
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
x, lu, info = lu_solve(
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b
)
assert info == 0
if overwrite_a:
assert_allclose(a, lu)
if overwrite_b:
assert_allclose(b, x)
assert_allclose(x, x_sp)
@pytest.mark.parametrize(
"b_shape",
[(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))
def A_func(x):
if assume_a == "pos":
x = x @ x.T
elif assume_a == "sym":
x = (x + x.T) / 2
return x
X = pt.linalg.solve(
A_func(A),
b,
assume_a=assume_a,
b_ndim=len(b_shape),
)
f = pytensor.function(
[pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
)
op = f.maker.fgraph.outputs[0].owner.op
compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy = A_val.copy()
b_val_copy = b_val.copy()
X_np = f(A_val, b_val)
# overwrite_b is preferred when both inputs can be destroyed
assert op.destroy_map == {0: [1]}
# Confirm inputs were destroyed by checking against the copies
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
# Confirm b_val is used to store to solution
np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
assert not np.allclose(b_val, b_val_copy)
# Test that the result is numerically correct. Need to use the unmodified copy
np.testing.assert_allclose(
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
)
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
utt.verify_grad(
lambda A, b: pt.linalg.solve(
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
),
[A_val_copy, b_val_copy],
mode="NUMBA",
)
@pytest.mark.parametrize(
"b_func, b_size",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
def test_cho_solve(b_func, b_size, lower):
A = pt.matrix("A", dtype=floatX)
b = b_func("b", dtype=floatX)
C = pt.linalg.cholesky(A, lower=lower)
X = pt.linalg.cho_solve((C, lower), b)
f = pytensor.function([A, b], X, mode="NUMBA")
A = np.random.normal(size=(5, 5)).astype(floatX)
A = A @ A.conj().T
b = np.random.normal(size=b_size)
b = b.astype(floatX)
X_np = f(A, b)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)
......@@ -209,12 +209,12 @@ class TestSolveBase:
)
class TestSolve(utt.InferShapeTester):
def test__init__(self):
with pytest.raises(ValueError) as excinfo:
Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value)
def test_solve_raises_on_invalid_A():
with pytest.raises(ValueError, match="is not a recognized matrix structure"):
Solve(assume_a="test", b_ndim=2)
class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
......@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester):
warn=False,
)
def test_correctness(self):
@pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
def test_solve_correctness(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = solve(A, b)
gen_solve_func = pytensor.function([A, b], y)
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_size)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
A_val = np.dot(A_val.transpose(), A_val)
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
np.testing.assert_allclose(
scipy.linalg.solve(A_val, b_val, assume_a="gen"),
gen_solve_func(A_val, b_val),
)
def A_func(x):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_input_val = A_func(A_val)
y = solve_op(A_func(A), b)
solve_func = pytensor.function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy())
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
A_undef = np.array(
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0],
],
dtype=config.floatX,
)
np.testing.assert_allclose(
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
X_np,
atol=ATOL,
rtol=RTOL,
)
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"m, n, assume_a, lower",
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
],
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
def test_solve_grad(self, m, n, assume_a, lower):
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.skipif(
config.floatX == "float32", reason="Gradients not numerically stable in float32"
)
def test_solve_gradient(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed())
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
eps = 2e-8 if config.floatX == "float64" else None
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
def A_func(x):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
# the random perturbations used by verify_grad will result in invalid input matrices, and
# LAPACK will silently do the wrong thing, making the gradients wrong
utt.verify_grad(
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
)
class TestSolveTriangular(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论