提交 bbe663d9 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Implement numba dispatch for all `linalg.solve` modes

上级 8e5e8a40
import ctypes
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.np.linalg import ensure_lapack, get_blas_kind
# Scalar ctypes types used to describe LAPACK argument lists.
_float = ctypes.c_float
_dbl = ctypes.c_double
_int = ctypes.c_int
_char = ctypes.c_char

# Pointer-type factory, plus one pointer type per scalar type above.
_PTR = ctypes.POINTER
_ptr_float, _ptr_dbl = _PTR(_float), _PTR(_dbl)
_ptr_char, _ptr_int = _PTR(_char), _PTR(_int)
def _get_lapack_ptr_and_ptr_type(dtype, name):
    """
    Look up the SciPy Cython LAPACK routine matching ``dtype`` and ``name``.

    The BLAS kind prefix returned by ``get_blas_kind(dtype)`` is prepended to ``name`` to form the routine name,
    which is then resolved in ``scipy.linalg.cython_lapack``.

    Returns
    -------
    tuple
        ``(address, float_pointer_type)``: the address returned by ``get_cython_function_address`` and the ctypes
        pointer type chosen by ``_get_float_pointer_for_dtype`` for that kind prefix.
    """
    kind_prefix = get_blas_kind(dtype)
    pointer_type = _get_float_pointer_for_dtype(kind_prefix)
    address = get_cython_function_address(
        "scipy.linalg.cython_lapack", f"{kind_prefix}{name}"
    )
    return address, pointer_type
def _get_underlying_float(dtype):
    """
    Return the numpy dtype of the real component of ``dtype``.

    ``complex64`` maps to ``float32`` and ``complex128`` maps to ``float64``; any other dtype name is passed to
    ``np.dtype`` unchanged. The argument is matched on ``str(dtype)``.
    """
    name = str(dtype)
    real_name = {"complex64": "float32", "complex128": "float64"}.get(name, name)
    return np.dtype(real_name)
def _get_float_pointer_for_dtype(blas_dtype):
    """
    Return the ctypes pointer type used for array arguments of a LAPACK routine.

    Parameters
    ----------
    blas_dtype : str
        BLAS kind prefix: ``"s"`` or ``"c"`` (single precision) or ``"d"`` or ``"z"`` (double precision).

    Returns
    -------
    ``_ptr_float`` for single-precision kinds, ``_ptr_dbl`` for double-precision kinds. Complex kinds share the
    pointer type of their real component.

    Raises
    ------
    ValueError
        If ``blas_dtype`` is not one of the four kind prefixes. Previously the function fell through and returned
        ``None``, which only surfaced later as an obscure ctypes error.
    """
    if blas_dtype in ["s", "c"]:
        return _ptr_float
    if blas_dtype in ["d", "z"]:
        return _ptr_dbl
    raise ValueError(
        f"Unsupported BLAS kind {blas_dtype!r}; expected one of 's', 'd', 'c', 'z'"
    )
def _get_output_ctype(dtype):
    """
    Return the ctypes scalar type of the real value returned by a LAPACK function for ``dtype``.

    The argument is matched on ``str(dtype)``: ``float32``/``complex64`` give ``_float`` and
    ``float64``/``complex128`` give ``_dbl``.

    Raises
    ------
    ValueError
        If ``dtype`` is none of the four supported names. Previously the function returned ``None``, which
        ``ctypes.CFUNCTYPE`` interprets as a ``void`` return type, silently dropping the routine's result.
    """
    s_dtype = str(dtype)
    if s_dtype in ["float32", "complex64"]:
        return _float
    if s_dtype in ["float64", "complex128"]:
        return _dbl
    raise ValueError(f"Unsupported dtype {s_dtype!r} for LAPACK return value")
@intrinsic
def sptr_to_val(typingctx, data):
    """Numba intrinsic typed as ``float32(CPointer(float32))``: load the value behind the pointer."""

    def impl(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.float32(types.CPointer(types.float32)), impl
@intrinsic
def dptr_to_val(typingctx, data):
    """Numba intrinsic typed as ``float64(CPointer(float64))``: load the value behind the pointer."""

    def impl(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.float64(types.CPointer(types.float64)), impl
@intrinsic
def int_ptr_to_val(typingctx, data):
    """Numba intrinsic typed as ``int32(CPointer(int32))``: load the value behind the pointer."""

    def impl(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.int32(types.CPointer(types.int32)), impl
@intrinsic
def val_to_int_ptr(typingctx, data):
    """
    Numba intrinsic typed as ``CPointer(int32)(int32)``: store the value in a slot created with
    ``cgutils.alloca_once_value`` and return the pointer to it.
    """

    def impl(context, builder, signature, args):
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.int32)(types.int32), impl
@intrinsic
def val_to_sptr(typingctx, data):
    """
    Numba intrinsic typed as ``CPointer(float32)(float32)``: store the value in a slot created with
    ``cgutils.alloca_once_value`` and return the pointer to it.
    """

    def impl(context, builder, signature, args):
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.float32)(types.float32), impl
@intrinsic
def val_to_zptr(typingctx, data):
    """
    Numba intrinsic typed as ``CPointer(complex128)(complex128)``: store the value in a slot created with
    ``cgutils.alloca_once_value`` and return the pointer to it.
    """

    def impl(context, builder, signature, args):
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.complex128)(types.complex128), impl
@intrinsic
def val_to_dptr(typingctx, data):
    """
    Numba intrinsic typed as ``CPointer(float64)(float64)``: store the value in a slot created with
    ``cgutils.alloca_once_value`` and return the pointer to it.
    """

    def impl(context, builder, signature, args):
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.float64)(types.float64), impl
class _LAPACK:
    """
    Functions to return type signatures for wrapped LAPACK functions.

    Each ``numba_x<routine>`` classmethod resolves the routine for the given dtype in
    ``scipy.linalg.cython_lapack`` and returns it wrapped in a ``ctypes.CFUNCTYPE`` whose argument list mirrors the
    Fortran interface: every argument is passed by reference, integer and character arguments as ``_ptr_int`` and
    array arguments as the float pointer type matching the dtype.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        # Instantiating the class is how callers make sure LAPACK is available before requesting a routine.
        ensure_lapack()

    @classmethod
    def numba_xtrtrs(cls, dtype):
        """
        Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.

        Called by scipy.linalg.solve_triangular
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # TRANS
            _ptr_int,  # DIAG
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xpotrf(cls, dtype):
        """
        Compute the Cholesky factorization of a real symmetric positive definite matrix.

        Called by scipy.linalg.cholesky
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xpotrs(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
        factorization computed by xPOTRF (see numba_xpotrf).

        Called by scipy.linalg.cho_solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xlange(cls, dtype):
        """
        Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
        a general M-by-N matrix A.

        Unlike most routines here this is a function: its return type is the real ctype chosen by
        ``_get_output_ctype(dtype)``.

        Called by scipy.linalg.solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange")
        output_ctype = _get_output_ctype(dtype)

        functype = ctypes.CFUNCTYPE(
            output_ctype,  # Output
            _ptr_int,  # NORM
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # WORK
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xlamch(cls, dtype):
        """
        Determine machine precision for floating point arithmetic.

        The routine takes no array arguments, so only the function address is needed from the lookup.
        """
        lapack_ptr, _ = _get_lapack_ptr_and_ptr_type(dtype, "lamch")
        output_dtype = _get_output_ctype(dtype)

        functype = ctypes.CFUNCTYPE(
            output_dtype,  # Output
            _ptr_int,  # CMACH
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xgecon(cls, dtype):
        """
        Estimates the condition number of a matrix A, using the LU factorization computed by numba_xgetrf.

        Called by scipy.linalg.solve when assume_a == "gen"
        """
        # NOTE(review): this argument list follows the real-valued (s/d) interface; the complex (c/z) variants of
        # xGECON are believed to take a real RWORK array instead of IWORK -- confirm before enabling complex dtypes.
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # NORM
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xgetrf(cls, dtype):
        """
        Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.

        Called by scipy.linalg.lu_factor
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xgetrs(cls, dtype):
        """
        Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
        factorization computed by GETRF.

        Called by scipy.linalg.lu_solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # TRANS
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xsysv(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
        factorizing A into LDL^T or UDU^T form, depending on the value of UPLO

        Called by scipy.linalg.solve when assume_a == "sym"
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # B
            _ptr_int,  # LDB
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xsycon(cls, dtype):
        """
        Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
        computed by xSYTRF.
        """
        # NOTE(review): argument list follows the real-valued (s/d) interface; the complex variants are believed
        # not to take IWORK -- confirm before enabling complex dtypes.
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xpocon(cls, dtype):
        """
        Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
        computed by potrf.

        Called by scipy.linalg.solve when assume_a == "pos"
        """
        # NOTE(review): argument list follows the real-valued (s/d) interface; the complex variants are believed
        # to take a real RWORK array instead of IWORK -- confirm before enabling complex dtypes.
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xposv(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
        factorization computed by potrf.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)
...@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs):
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from a Pytensor `Op`."""
warnings.warn( warnings.warn(
f"Numba will use object mode to run {op}'s perform method", f"Numba will use object mode to run {op}'s perform method",
......
import ctypes from collections.abc import Callable
import numba import numba
import numpy as np import numpy as np
from numba.core import cgutils, types from numba.core import types
from numba.extending import get_cython_function_address, intrinsic, overload from numba.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numpy.linalg import LinAlgError
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
CholeskySolve,
Solve,
SolveTriangular,
)
_PTR = ctypes.POINTER @numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
_dbl = ctypes.c_double """
_float = ctypes.c_float Check arguments during the different steps of the solution phase
_char = ctypes.c_char Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
_int = ctypes.c_int """
if info < 0:
_ptr_float = _PTR(_float) # TODO: figure out how to do an fstring here
_ptr_dbl = _PTR(_dbl) msg = "LAPACK reported an illegal value in input"
_ptr_char = _PTR(_char) raise ValueError(msg)
_ptr_int = _PTR(_int) elif 0 < info:
raise LinAlgError("Matrix is singular.")
@numba.core.extending.register_jitable if lamch:
def _check_finite_matrix(a, func_name): E = _xlamch("E")
for v in np.nditer(a): if rcond < E:
if not np.isfinite(v.item()): # TODO: This should be a warning, but we can't raise warnings in numba mode
raise np.linalg.LinAlgError( print( # noqa: T201
"Non-numeric values (nan or inf) in input to " + func_name "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate."
) )
@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float64)(types.float64)
return sig, impl
@intrinsic
def val_to_zptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.complex128)(types.complex128)
return sig, impl
@intrinsic
def val_to_sptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.float32)(types.float32)
return sig, impl
@intrinsic
def val_to_int_ptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr
sig = types.CPointer(types.int32)(types.int32)
return sig, impl
@intrinsic
def int_ptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.int32(types.CPointer(types.int32))
return sig, impl
@intrinsic
def dptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float64(types.CPointer(types.float64))
return sig, impl
@intrinsic
def sptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val
sig = types.float32(types.CPointer(types.float32))
return sig, impl
def _get_float_pointer_for_dtype(blas_dtype):
if blas_dtype in ["s", "c"]:
return _ptr_float
elif blas_dtype in ["d", "z"]:
return _ptr_dbl
def _get_underlying_float(dtype):
s_dtype = str(dtype)
out_type = s_dtype
if s_dtype == "complex64":
out_type = "float32"
elif s_dtype == "complex128":
out_type = "float64"
return np.dtype(out_type)
def _get_lapack_ptr_and_ptr_type(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
return lapack_ptr, float_pointer
def _check_scipy_linalg_matrix(a, func_name): def _check_scipy_linalg_matrix(a, func_name):
""" """
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
...@@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name): ...@@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name):
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
class _LAPACK: def _solve_triangular(
A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False
):
""" """
Functions to return type signatures for wrapped LAPACK functions. Thin wrapper around scipy.linalg.solve_triangular.
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
""" import pytensor.
def __init__(self):
ensure_lapack()
@classmethod The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
def numba_xtrtrs(cls, dtype): used by scipy.linalg.solve_triangular.
""" """
Called by scipy.linalg.solve_triangular return linalg.solve_triangular(
""" A,
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") B,
trans=trans,
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
)
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # TRANS
_ptr_int, # DIAG
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
)
return functype(lapack_ptr) @numba_basic.numba_njit(inline="always")
def _trans_char_to_int(trans):
@classmethod if trans not in [0, 1, 2]:
def numba_xpotrf(cls, dtype): raise ValueError('Parameter "trans" should be one of 0, 1, 2')
""" if trans == 0:
Called by scipy.linalg.cholesky return ord("N")
""" elif trans == 1:
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") return ord("T")
functype = ctypes.CFUNCTYPE( else:
None, return ord("C")
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): @numba_basic.numba_njit(inline="always")
return linalg.solve_triangular( def _solve_check_input_shapes(A, B):
A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal if A.shape[0] != B.shape[0]:
) raise linalg.LinAlgError("Dimensions of A and B do not conform")
if A.shape[-2] != A.shape[-1]:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
@overload(_solve_triangular) @overload(_solve_triangular)
def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular") _check_scipy_linalg_matrix(A, "solve_triangular")
...@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): ...@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
def impl(A, B, trans=0, lower=False, unit_diagonal=False): def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
B_is_1d = B.ndim == 1
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if A.shape[-2] != _N: _solve_check_input_shapes(A, B)
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
if A.shape[0] != B.shape[0]: B_is_1d = B.ndim == 1
raise linalg.LinAlgError("Dimensions of A and B do not conform")
if B_is_1d: if not overwrite_b:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B) B_copy = _copy_to_fortran_order(B)
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
if trans == 0:
transval = ord("N")
elif trans == 1:
transval = ord("T")
else: else:
transval = ord("C") B_copy = B
B_NDIM = 1 if B_is_1d else int(B.shape[1]) if B_is_1d:
B_copy = np.expand_dims(B, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(transval) TRANS = val_to_int_ptr(_trans_char_to_int(trans))
DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
N = val_to_int_ptr(_N) N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(B_NDIM) NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N) LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0) INFO = val_to_int_ptr(0)
...@@ -266,19 +158,24 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): ...@@ -266,19 +158,24 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
INFO, INFO,
) )
_solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
if B_is_1d: if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO) return B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
return B_copy
return impl return impl
@numba_funcify.register(SolveTriangular) @numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs): def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = op.trans trans = bool(op.trans)
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite check_finite = op.check_finite
overwrite_b = op.overwrite_b
b_ndim = op.b_ndim
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"): if str(dtype).startswith("complex"):
...@@ -298,11 +195,16 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -298,11 +195,16 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to solve_triangular" "Non-numeric values (nan or inf) in input b to solve_triangular"
) )
res, info = _solve_triangular(a, b, trans, lower, unit_diagonal) res = _solve_triangular(
if info != 0: a,
raise np.linalg.LinAlgError( b,
"Singular matrix in input A to solve_triangular" trans=trans,
) lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
b_ndim=b_ndim,
)
return res return res
return solve_triangular return solve_triangular
...@@ -429,3 +331,853 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -429,3 +331,853 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return out return out
return block_diag return block_diag
def _xlamch(kind: str = "E"):
    """
    Overload target for querying machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.

    The pure-Python body does nothing and returns ``None``; the numba overload registered on this function
    provides the actual behaviour.
    """
    return None
@overload(_xlamch)
def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
    """
    Compute the machine precision for a given floating point type.

    The floating point type is not taken from an argument: it is read from pytensor's ``config.floatX`` when the
    overload is resolved. Only float32 and float64 are handled; anything else raises ``NotImplementedError``.
    """
    from pytensor import config

    ensure_lapack()
    real_dtype = _get_underlying_float(config.floatX)

    if real_dtype == "float64":
        dtype = types.float64
    elif real_dtype == "float32":
        dtype = types.float32
    else:
        raise NotImplementedError("Unsupported dtype")

    numba_lamch = _LAPACK().numba_xlamch(dtype)

    def impl(kind: str = "E") -> float:
        # LAPACK receives the query character by reference, encoded as its integer code point.
        cmach = val_to_int_ptr(ord(kind))
        return numba_lamch(cmach)  # type: ignore

    return impl
def _xlange(A: np.ndarray, order: str | None = None) -> float:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return # type: ignore
@overload(_xlange)
def xlange_impl(
    A: np.ndarray, order: str | None = None
) -> Callable[[np.ndarray, str], float]:
    """
    xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
    largest absolute value of a matrix A.

    ``order`` is the single-character norm selector forwarded to LAPACK; when it is ``None`` the character "1" is
    used.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "norm")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_lange = _LAPACK().numba_xlange(dtype)

    def impl(A: np.ndarray, order: str | None = None):
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        # The routine reads the matrix as column-major, so hand it a Fortran-ordered copy.
        A_copy = _copy_to_fortran_order(A)

        NORM = (
            val_to_int_ptr(ord("1"))
            if order is None
            else val_to_int_ptr(ord(order))
        )
        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        WORK = np.empty(_M, dtype=dtype)  # type: ignore

        return numba_lange(
            NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
        )

    return impl
def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
    """
    Overload target for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to
    numbify graphs.

    The pure-Python body returns ``None``.
    """
    return None  # type: ignore
@overload(_xgecon)
def xgecon_impl(
    A: np.ndarray, A_norm: float, norm: str
) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
    """
    Compute the condition number of a matrix A.

    The generated implementation returns ``(RCOND, info)`` where ``RCOND`` is a length-1 array written by the
    LAPACK routine and ``info`` is the routine's INFO value.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "gecon")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_gecon = _LAPACK().numba_xgecon(dtype)

    def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
        _N = np.int32(A.shape[-1])
        A_copy = _copy_to_fortran_order(A)

        # Scalar arguments, passed by reference.
        NORM = val_to_int_ptr(ord(norm))
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(1)

        # Array arguments: input norm, output slot and workspaces.
        A_NORM = np.array(A_norm, dtype=dtype)
        RCOND = np.empty(1, dtype=dtype)
        WORK = np.empty(4 * _N, dtype=dtype)
        IWORK = np.empty(_N, dtype=np.int32)

        numba_gecon(
            NORM,
            N,
            A_copy.view(w_type).ctypes,
            LDA,
            A_NORM.view(w_type).ctypes,
            RCOND.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            IWORK.ctypes,
            INFO,
        )

        info = int_ptr_to_val(INFO)
        return RCOND, info

    return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Overload target for LU factorization; used by linalg.solve.

    The pure-Python body returns ``None``.
    """
    # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
    return None  # type: ignore
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """
    Numba overload of ``_getrf``: LU-factor ``A`` with the LAPACK xGETRF routine.

    The generated implementation returns ``(LU, IPIV, info)``. Unless ``overwrite_a`` is set, the routine works on
    a Fortran-ordered copy of ``A``; otherwise ``A`` itself is handed to LAPACK and returned.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        if overwrite_a:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        INFO = val_to_int_ptr(0)
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        info = int_ptr_to_val(INFO)
        return A_copy, IPIV, info

    return impl
def _getrs(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
    """
    Overload target for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.

    The pure-Python body returns ``None``.
    """
    # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
    return None  # type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_getrs``: solve a linear system from an LU factorization with the LAPACK xGETRS routine.

    The generated implementation returns ``(X, info)``. A 1-d ``B`` is expanded to a single column for the call
    and the result is returned 1-d again. Unless ``overwrite_b`` is set, ``B`` is first copied to Fortran order.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)

    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)

        B_is_1d = B.ndim == 1

        if overwrite_b:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        if B_is_1d:
            B_copy = np.expand_dims(B_copy, -1)

        n_rhs = 1 if B_is_1d else int(B_copy.shape[-1])

        # Scalar arguments, passed by reference.
        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(n_rhs)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        IPIV = _copy_to_fortran_order(IPIV)

        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        if B_is_1d:
            return B_copy[..., 0], int_ptr_to_val(INFO)
        return B_copy, int_ptr_to_val(INFO)

    return impl
def _solve_gen(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
):
    """
    Solve a linear system through scipy.linalg.solve with ``assume_a="gen"``.

    This wrapper exists so that numba can overload it rather than scipy's own function, which avoids unexpected
    side-effects for users who import pytensor. Every argument is forwarded unchanged.
    """
    solve_options = dict(
        lower=lower,
        overwrite_a=overwrite_a,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
        assume_a="gen",
        transposed=transposed,
    )
    return linalg.solve(A, B, **solve_options)
@overload(_solve_gen)
def solve_gen_impl(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
    """
    Numba overload of ``_solve_gen``: solve ``A @ X = B`` (or ``A.T @ X = B`` when ``transposed``) for a general
    square matrix.

    The generated implementation runs three steps, each followed by ``_solve_check`` on the reported INFO value:
    LU factorization (``_getrf``), back-substitution (``_getrs``), and a reciprocal condition number estimate
    (``_xgecon``). ``lower`` and ``check_finite`` are accepted for signature compatibility but are not read here.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: np.ndarray,
        B: np.ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> np.ndarray:
        _solve_check_input_shapes(A, B)

        # The norm has to be taken before the factorization, which may overwrite A.
        order = "I" if transposed else "1"
        norm = _xlange(A, order=order)

        N = A.shape[1]
        LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
        _solve_check(N, INFO)

        X, INFO = _getrs(
            LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b
        )
        _solve_check(N, INFO)

        # The condition estimate must use the same norm type that produced `norm`; previously "1" was always
        # passed, pairing an infinity-norm value with a 1-norm estimate whenever `transposed` was set.
        RCOND, INFO = _xgecon(LU, norm, order)
        _solve_check(N, INFO, True, RCOND)

        return X

    return impl
def _sysv(
    A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Overload target for solving a linear system with a symmetric matrix; used by linalg.solve.

    The pure-Python body returns ``None``.
    """
    return None  # type: ignore
@overload(_sysv)
def sysv_impl(
    A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
    [np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
]:
    """
    Numba overload of ``_sysv``: solve ``A @ X = B`` for symmetric ``A`` with the LAPACK xSYSV routine.

    The generated implementation calls the routine twice: first with ``LWORK = -1`` to ask for the workspace size,
    then with a workspace of that size to do the solve. It returns ``(X, IPIV, info)``.
    """
    # NOTE(review): the matrix buffer handed to LAPACK (A_copy) is not part of the return value -- confirm
    # whether callers need the factorization it holds after the call.
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "sysv")
    _check_scipy_linalg_matrix(B, "sysv")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_sysv = _LAPACK().numba_xsysv(dtype)

    def impl(
        A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
    ):
        _LDA, _N = np.int32(A.shape[-2:])  # type: ignore
        _solve_check_input_shapes(A, B)

        if overwrite_a:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        B_is_1d = B.ndim == 1

        if overwrite_b:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        if B_is_1d:
            B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))

        n_rhs = 1 if B_is_1d else int(B.shape[-1])

        # Scalar arguments, passed by reference.
        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        N = val_to_int_ptr(_N)  # type: ignore
        NRHS = val_to_int_ptr(n_rhs)
        LDA = val_to_int_ptr(_LDA)  # type: ignore
        LDB = val_to_int_ptr(_N)  # type: ignore
        INFO = val_to_int_ptr(0)
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore

        # Workspace query
        query_work = np.empty(1, dtype=dtype)
        query_lwork = val_to_int_ptr(-1)
        numba_sysv(
            UPLO,
            N,
            NRHS,
            A_copy.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            query_work.view(w_type).ctypes,
            query_lwork,
            INFO,
        )

        # The requested size comes back in the first workspace element.
        WS_SIZE = np.int32(query_work[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual solve
        numba_sysv(
            UPLO,
            N,
            NRHS,
            A_copy.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )

        if B_is_1d:
            return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
        return B_copy, IPIV, int_ptr_to_val(INFO)

    return impl
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
    """
    Overload target for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
    python mode.

    The pure-Python body returns ``None``.
    """
    return None  # type: ignore
@overload(_sycon)
def sycon_impl(
    A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
    """
    Numba overload of ``_sycon``: estimate the reciprocal condition number of a symmetric matrix with the LAPACK
    xSYCON routine.

    The generated implementation returns ``(RCOND, info)`` where ``RCOND`` is a length-1 array written by the
    routine.
    """
    # NOTE(review): UPLO is fixed to "L" below, while the solver in this file derives UPLO from its `lower`
    # argument -- confirm the two always agree for the factorization passed in.
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "sycon")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_sycon = _LAPACK().numba_xsycon(dtype)

    def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
        _N = np.int32(A.shape[-1])
        A_copy = _copy_to_fortran_order(A)

        # Scalar arguments, passed by reference.
        UPLO = val_to_int_ptr(ord("L"))
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        # Array arguments: input norm, output slot and workspaces.
        ANORM = np.array(anorm, dtype=dtype)
        RCOND = np.empty(1, dtype=dtype)
        WORK = np.empty(2 * _N, dtype=dtype)
        IWORK = np.empty(_N, dtype=np.int32)

        numba_sycon(
            UPLO,
            N,
            A_copy.view(w_type).ctypes,
            LDA,
            ipiv.ctypes,
            ANORM.view(w_type).ctypes,
            RCOND.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            IWORK.ctypes,
            INFO,
        )

        info = int_ptr_to_val(INFO)
        return RCOND, info

    return impl
def _solve_symmetric(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
):
    """
    Solve a linear system through scipy.linalg.solve with ``assume_a="sym"``.

    This wrapper exists so that numba can overload it rather than scipy's own function, which avoids unexpected
    side-effects when users import pytensor. Every argument is forwarded unchanged.
    """
    solve_options = dict(
        lower=lower,
        overwrite_a=overwrite_a,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
        assume_a="sym",
        transposed=transposed,
    )
    return linalg.solve(A, B, **solve_options)
@overload(_solve_symmetric)
def solve_symmetric_impl(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
    """
    Numba overload of ``_solve_symmetric``: solve ``A @ X = B`` for symmetric ``A``.

    The generated implementation solves with ``_sysv`` and then estimates the reciprocal condition number with
    ``_sycon``, running ``_solve_check`` on the INFO value after each step. ``check_finite`` and ``transposed``
    are accepted for signature compatibility but are not read here.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: np.ndarray,
        B: np.ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> np.ndarray:
        _solve_check_input_shapes(A, B)

        x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
        _solve_check(A.shape[-1], info)

        # NOTE(review): `A` here is the caller's array, not the factorization produced inside `_sysv` (unless it
        # was overwritten in place) -- confirm this is the matrix the condition estimate is meant to receive.
        anorm = _xlange(A, order="I")
        rcond, info = _sycon(A, ipiv, anorm)
        _solve_check(A.shape[-1], info, True, rcond)

        return x

    return impl
def _posv(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> tuple[np.ndarray, int]:
    """
    Overload target for solving a linear system with a positive-definite matrix; used by linalg.solve.

    The pure-Python body returns ``None``.
    """
    return None  # type: ignore
@overload(_posv)
def posv_impl(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[
    [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
]:
    """Numba overload of `_posv`: solve ``A @ X = B`` with the LAPACK posv routine.

    The routine and the working float type are selected from the dtype of `A`.
    The returned implementation yields ``(X, info)``, where `X` has the same
    number of dimensions as `B` and `info` is the LAPACK status code.

    `overwrite_a` / `overwrite_b` are only honoured when the corresponding
    array is Fortran-contiguous; otherwise a Fortran-ordered copy is used.
    `check_finite` and `transposed` are accepted but not used.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_posv = _LAPACK().numba_xposv(dtype)

    def impl(
        A: np.ndarray,
        B: np.ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> tuple[np.ndarray, int]:
        _solve_check_input_shapes(A, B)

        _N = np.int32(A.shape[-1])

        # The raw buffer is handed to LAPACK with LDA = N, i.e. it is read as a
        # column-major N x N matrix. Working in place is therefore only valid
        # when A really has that layout.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        B_is_1d = B.ndim == 1

        # Same reasoning for B (LDB = N).
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        elif B_is_1d:
            # A vector is trivially column-major; an explicit copy guarantees
            # that the caller's B is never aliased when overwriting is off.
            B_copy = B.copy()
        else:
            B_copy = _copy_to_fortran_order(B)

        if B_is_1d:
            # LAPACK expects a 2d right-hand side, so promote to (N, 1).
            B_copy = np.expand_dims(B_copy, -1)

        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        NRHS = 1 if B_is_1d else int(B.shape[-1])

        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        # On exit B_copy holds the solution.
        numba_posv(
            UPLO,
            N,
            NRHS,
            A_copy.view(w_type).ctypes,
            LDA,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        if B_is_1d:
            # Drop the column axis added above.
            return B_copy[..., 0], int_ptr_to_val(INFO)
        return B_copy, int_ptr_to_val(INFO)

    return impl
def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
    """
    Overload target for the condition-number estimate of a cholesky-factorized
    positive-definite matrix, used by linalg.solve when assume_a = "pos".

    The Python body is deliberately empty: the function only provides a name
    that the numba overload is registered on.
    """
    return None  # type: ignore
@overload(_pocon)
def pocon_impl(
    A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
    """Numba overload of `_pocon`, wrapping the LAPACK pocon routine.

    The routine and the working float type are selected from the dtype of `A`.
    The returned implementation takes the matrix `A` and a matrix norm `anorm`
    and returns ``(RCOND, info)``: a one-element array written by LAPACK and
    the LAPACK status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "pocon")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_pocon = _LAPACK().numba_xpocon(dtype)

    def impl(A: np.ndarray, anorm: float):
        _N = np.int32(A.shape[-1])
        # Always work on a Fortran-ordered copy; the caller's array is not passed to LAPACK.
        A_copy = _copy_to_fortran_order(A)
        # NOTE(review): the triangle is hard-coded to "L" here; callers that
        # factorized with the upper triangle presumably need a matching flag -- confirm.
        UPLO = val_to_int_ptr(ord("L"))
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        # Scalars are passed to LAPACK by pointer, hence the array wrappers.
        ANORM = np.array(anorm, dtype=dtype)
        RCOND = np.empty(1, dtype=dtype)
        # Scratch buffers of length 3 * N and N respectively.
        WORK = np.empty(3 * _N, dtype=dtype)
        IWORK = np.empty(_N, dtype=np.int32)
        INFO = val_to_int_ptr(0)
        numba_pocon(
            UPLO,
            N,
            A_copy.view(w_type).ctypes,
            LDA,
            ANORM.view(w_type).ctypes,
            RCOND.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            IWORK.ctypes,
            INFO,
        )
        return RCOND, int_ptr_to_val(INFO)

    return impl
def _solve_psd(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
):
    """Solve ``A @ X = B`` with scipy.linalg.solve, declaring ``A`` positive-definite.

    This plain-Python function exists so that numba has a private overload
    target; overloading scipy.linalg.solve itself would cause side-effects
    for users as soon as pytensor is imported.
    """
    # Every flag is forwarded unchanged; only the structure hint is fixed here.
    solve_options = {
        "lower": lower,
        "overwrite_a": overwrite_a,
        "overwrite_b": overwrite_b,
        "check_finite": check_finite,
        "transposed": transposed,
        "assume_a": "pos",
    }
    return linalg.solve(A, B, **solve_options)
@overload(_solve_psd)
def solve_psd_impl(
    A: np.ndarray,
    B: np.ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
    """Numba overload of `_solve_psd`.

    The returned implementation solves the system with `_posv` and then
    estimates the reciprocal condition number with `_pocon`, handing both
    status codes to `_solve_check`.

    `_pocon` needs the Cholesky factor, which `_posv` does not return. The
    factorization is therefore always carried out in place on a private
    Fortran-ordered copy of `A`, and that copy is what `_pocon` receives.
    As a consequence the caller's `A` is never modified, even when
    `overwrite_a` is true (the flag is accepted for interface compatibility).
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: np.ndarray,
        B: np.ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> np.ndarray:
        _solve_check_input_shapes(A, B)
        n = A.shape[-1]

        # The norm has to come from the original matrix, so take it before
        # anything is factorized.
        anorm = _xlange(A)

        # `_pocon` always reads the lower triangle. Transposing moves the upper
        # triangle of A into the lower triangle of the working copy, so the
        # triangle selected by `lower` is the one that gets factorized and read.
        if lower:
            A_fact = _copy_to_fortran_order(A)
        else:
            A_fact = _copy_to_fortran_order(A.T)

        # overwrite_a=True on the private copy: on exit A_fact holds the
        # Cholesky factor in its lower triangle.
        x, info = _posv(A_fact, B, True, True, overwrite_b, check_finite, transposed)
        _solve_check(n, info)

        rcond, info = _pocon(A_fact, anorm)
        _solve_check(n, info=info, lamch=True, rcond=rcond)

        return x

    return impl
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
    """Build the numba implementation of a `Solve` op.

    The solver is chosen from ``op.assume_a`` ("gen", "sym" or "pos"). The
    returned jitted function takes the two inputs ``(a, b)``, optionally
    rejects nan/inf values, and returns the solution.

    Raises
    ------
    NotImplementedError
        If any input is complex, or if ``op.assume_a`` is "her" or any other
        value that has no numba implementation.
    """
    assume_a = op.assume_a
    lower = op.lower
    check_finite = op.check_finite
    overwrite_a = op.overwrite_a
    overwrite_b = op.overwrite_b
    transposed = False  # TODO: Solve doesn't currently allow the transposed argument

    # The LAPACK wrappers derive the working float type from A alone and use it
    # to reinterpret the buffers of both operands, so a complex `b` must be
    # rejected as well, not only a complex `A`.
    if any(str(inp.dtype).startswith("complex") for inp in node.inputs):
        raise NotImplementedError(
            "Complex inputs not currently supported by solve in Numba mode"
        )

    if assume_a == "gen":
        solve_fn = _solve_gen
    elif assume_a == "sym":
        solve_fn = _solve_symmetric
    elif assume_a == "her":
        raise NotImplementedError(
            'Use assume_a = "sym" for symmetric real matrices. If you need complex support, '
            "please open an issue on github."
        )
    elif assume_a == "pos":
        solve_fn = _solve_psd
    else:
        raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")

    @numba_basic.numba_njit(inline="always")
    def solve(a, b):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) in input A to solve"
                )
            if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) in input b to solve"
                )

        # The op-level settings are closed over and passed positionally.
        res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
        return res

    return solve
def _cho_solve(
    A_and_lower,
    B,
    overwrite_a=False,
    overwrite_b=False,
    check_finite=True,
    lower=None,
):
    """
    Solve a positive-definite linear system using the Cholesky decomposition.

    The factor can be given in either of two forms:

    * a ``(C, lower)`` pair, as accepted by scipy.linalg.cho_solve, or
    * the factor array alone, with the triangle selected by the `lower`
      keyword. This is the form used by the numba overload registered on this
      function and by its caller, which pass `lower` by keyword.

    Parameters
    ----------
    A_and_lower
        Either a ``(C, lower)`` tuple/list or the Cholesky factor itself.
    B
        Right-hand side.
    overwrite_a
        Unused; kept for signature compatibility.
    overwrite_b, check_finite
        Forwarded to scipy.linalg.cho_solve.
    lower
        Which triangle holds the factor. When given it takes precedence over
        the flag inside a ``(C, lower)`` pair; when omitted and no pair is
        given it defaults to False.

    Returns
    -------
    The solution returned by scipy.linalg.cho_solve. Note that the numba
    overload additionally returns the LAPACK status code.
    """
    if isinstance(A_and_lower, (tuple, list)):
        C, lower_from_pair = A_and_lower
        if lower is None:
            lower = lower_from_pair
    else:
        C = A_and_lower
        if lower is None:
            lower = False

    return linalg.cho_solve(
        (C, lower), B, overwrite_b=overwrite_b, check_finite=check_finite
    )
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
    """Numba overload of `_cho_solve`, wrapping the LAPACK potrs routine.

    `C` is the Cholesky factor and `lower` says which triangle holds it. The
    routine and the working float type are selected from the dtype of `C`.
    The returned implementation yields ``(X, info)``, where `X` has the same
    number of dimensions as `B` and `info` is the LAPACK status code.

    `overwrite_b` is only honoured when `B` is Fortran-contiguous; otherwise
    the solution is written to a copy. `check_finite` is accepted but not used.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(C, "cho_solve")
    _check_scipy_linalg_matrix(B, "cho_solve")
    dtype = C.dtype
    w_type = _get_underlying_float(dtype)
    numba_potrs = _LAPACK().numba_xpotrs(dtype)

    def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
        _solve_check_input_shapes(C, B)

        _N = np.int32(C.shape[-1])
        C_copy = _copy_to_fortran_order(C)

        B_is_1d = B.ndim == 1

        # The raw buffer is handed to LAPACK with LDB = N, i.e. it is read as
        # column-major, so B may only be used in place when it has that layout.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        elif B_is_1d:
            # A vector is trivially column-major; an explicit copy guarantees
            # that the caller's B is never aliased when overwriting is off.
            B_copy = B.copy()
        else:
            B_copy = _copy_to_fortran_order(B)

        if B_is_1d:
            # LAPACK expects a 2d right-hand side, so promote to (N, 1).
            B_copy = np.expand_dims(B_copy, -1)

        NRHS = 1 if B_is_1d else int(B.shape[-1])

        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        # On exit B_copy holds the solution.
        numba_potrs(
            UPLO,
            N,
            NRHS,
            C_copy.view(w_type).ctypes,
            LDA,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        if B_is_1d:
            # Drop the column axis added above.
            return B_copy[..., 0], int_ptr_to_val(INFO)
        return B_copy, int_ptr_to_val(INFO)

    return impl
@numba_funcify.register(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs):
    """Build the numba implementation of a `CholeskySolve` op.

    The returned jitted function takes the Cholesky factor `c` and the
    right-hand side `b`, optionally rejects nan/inf values, calls `_cho_solve`
    and converts a non-zero LAPACK status code into a LinAlgError.

    Raises
    ------
    NotImplementedError
        If any input is complex.
    """
    lower = op.lower
    overwrite_b = op.overwrite_b
    check_finite = op.check_finite

    # The LAPACK wrapper derives the working float type from the factor alone
    # and uses it to reinterpret the buffers of both operands, so a complex `b`
    # must be rejected as well, not only a complex factor.
    if any(str(inp.dtype).startswith("complex") for inp in node.inputs):
        raise NotImplementedError(
            "Complex inputs not currently supported by cho_solve in Numba mode"
        )

    @numba_basic.numba_njit(inline="always")
    def cho_solve(c, b):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) in input A to cho_solve"
                )
            if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) in input b to cho_solve"
                )

        res, info = _cho_solve(
            c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
        )

        # Translate the LAPACK status code into an exception.
        if info < 0:
            raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
        elif info > 0:
            raise np.linalg.LinAlgError(
                "Matrix is not positive definite in input to cho_solve"
            )

        return res

    return cho_solve
import logging import logging
import typing
import warnings import warnings
from collections.abc import Sequence
from functools import reduce from functools import reduce
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
...@@ -58,7 +58,7 @@ class Cholesky(Op): ...@@ -58,7 +58,7 @@ class Cholesky(Op):
f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
) )
# Call scipy to find output dtype # Call scipy to find output dtype
dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -68,21 +68,21 @@ class Cholesky(Op): ...@@ -68,21 +68,21 @@ class Cholesky(Op):
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]: if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy.linalg.cholesky( out[0] = scipy_linalg.cholesky(
x.T, x.T,
lower=not self.lower, lower=not self.lower,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_a=True, overwrite_a=True,
).T ).T
else: else:
out[0] = scipy.linalg.cholesky( out[0] = scipy_linalg.cholesky(
x, x,
lower=self.lower, lower=self.lower,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_a=self.overwrite_a, overwrite_a=self.overwrite_a,
) )
except scipy.linalg.LinAlgError: except scipy_linalg.LinAlgError:
if self.on_error == "raise": if self.on_error == "raise":
raise raise
else: else:
...@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase): ...@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
C, b = inputs C, b = inputs
rval = scipy.linalg.cho_solve( rval = scipy_linalg.cho_solve(
(C, self.lower), (C, self.lower),
b, b,
check_finite=self.check_finite, check_finite=self.check_finite,
...@@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): ...@@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int b_ndim : int
Whether the core case of b is a vector (1) or matrix (2). Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
...@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase): ...@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
A, b = inputs A, b = inputs
outputs[0][0] = scipy.linalg.solve_triangular( outputs[0][0] = scipy_linalg.solve_triangular(
A, A,
b, b,
lower=self.lower, lower=self.lower,
...@@ -502,7 +502,7 @@ class Solve(SolveBase): ...@@ -502,7 +502,7 @@ class Solve(SolveBase):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
a, b = inputs a, b = inputs
outputs[0][0] = scipy.linalg.solve( outputs[0][0] = scipy_linalg.solve(
a=a, a=a,
b=b, b=b,
lower=self.lower, lower=self.lower,
...@@ -619,9 +619,9 @@ class Eigvalsh(Op): ...@@ -619,9 +619,9 @@ class Eigvalsh(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(w,) = outputs (w,) = outputs
if len(inputs) == 2: if len(inputs) == 2:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
else: else:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
a, b = inputs a, b = inputs
...@@ -675,7 +675,7 @@ class EigvalshGrad(Op): ...@@ -675,7 +675,7 @@ class EigvalshGrad(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(a, b, gw) = inputs (a, b, gw) = inputs
w, v = scipy.linalg.eigh(a, b, lower=self.lower) w, v = scipy_linalg.eigh(a, b, lower=self.lower)
gA = v.dot(np.diag(gw).dot(v.T)) gA = v.dot(np.diag(gw).dot(v.T))
gB = -v.dot(np.diag(gw * w).dot(v.T)) gB = -v.dot(np.diag(gw * w).dot(v.T))
...@@ -718,7 +718,7 @@ class Expm(Op): ...@@ -718,7 +718,7 @@ class Expm(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(A,) = inputs (A,) = inputs
(expm,) = outputs (expm,) = outputs
expm[0] = scipy.linalg.expm(A) expm[0] = scipy_linalg.expm(A)
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
(A,) = inputs (A,) = inputs
...@@ -758,8 +758,8 @@ class ExpmGrad(Op): ...@@ -758,8 +758,8 @@ class ExpmGrad(Op):
# this expression. # this expression.
(A, gA) = inputs (A, gA) = inputs
(out,) = outputs (out,) = outputs
w, V = scipy.linalg.eig(A, right=True) w, V = scipy_linalg.eig(A, right=True)
U = scipy.linalg.inv(V).T U = scipy_linalg.inv(V).T
exp_w = np.exp(w) exp_w = np.exp(w)
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
...@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op): ...@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op):
X = output_storage[0] X = output_storage[0]
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op): ...@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op):
X = output_storage[0] X = output_storage[0]
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype out_dtype
) )
...@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op): ...@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op):
Q = 0.5 * (Q + Q.T) Q = 0.5 * (Q + Q.T)
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -1064,7 +1064,7 @@ def solve_discrete_are( ...@@ -1064,7 +1064,7 @@ def solve_discrete_are(
) )
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
...@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal): ...@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal):
def perform(self, node, inputs, output_storage, params=None): def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype) output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype)
def block_diag(*matrices: TensorVariable): def block_diag(*matrices: TensorVariable):
...@@ -1175,4 +1175,5 @@ __all__ = [ ...@@ -1175,4 +1175,5 @@ __all__ = [
"solve_discrete_are", "solve_discrete_are",
"solve_triangular", "solve_triangular",
"block_diag", "block_diag",
"cho_solve",
] ]
...@@ -7,58 +7,13 @@ import pytensor.tensor as pt ...@@ -7,58 +7,13 @@ import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg, slinalg from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, exc", "x, exc",
[ [
......
import re import re
from functools import partial
from typing import Literal
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_allclose
from scipy import linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
ATOL = 0 if config.floatX.endswith("64") else 1e-6 floatX = pytensor.config.floatX
RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -27,8 +31,8 @@ def transpose_func(x, trans): ...@@ -27,8 +31,8 @@ def transpose_func(x, trans):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"b_func, b_size", "b_shape",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], [(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"], ids=["b_col_vec", "b_matrix", "b_vec"],
) )
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
...@@ -36,50 +40,88 @@ def transpose_func(x, trans): ...@@ -36,50 +40,88 @@ def transpose_func(x, trans):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
) )
@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) @pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
@pytest.mark.filterwarnings( @pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"' 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
) )
def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex):
if complex: if is_complex:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why? # why?
pytest.skip("Complex inputs currently not supported to solve_triangular") pytest.skip("Complex inputs currently not supported to solve_triangular")
complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if complex else config.floatX dtype = complex_dtype if is_complex else floatX
A = pt.matrix("A", dtype=dtype) A = pt.matrix("A", dtype=dtype)
b = b_func("b", dtype=dtype) b = pt.tensor("b", shape=b_shape, dtype=dtype)
def A_func(x):
x = x @ x.conj().T
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
X = pt.linalg.solve_triangular( if unit_diag:
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag x_tri[np.diag_indices_from(x_tri)] = 1.0
return x_tri.astype(dtype)
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
) )
X = solve_op(A, b)
f = pytensor.function([A, b], X, mode="NUMBA") f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5)) A_val = np.random.normal(size=(5, 5))
b = np.random.normal(size=b_size) b_val = np.random.normal(size=b_shape)
if complex: if is_complex:
A_val = A_val + np.random.normal(size=(5, 5)) * 1j A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b = b + np.random.normal(size=b_size) * 1j b_val = b_val + np.random.normal(size=b_shape) * 1j
A_sym = A_val @ A_val.conj().T
A_tri = np.linalg.cholesky(A_sym).astype(dtype) X_np = f(A_func(A_val.copy()), b_val.copy())
if unit_diag:
adj_mat = np.ones((5, 5))
adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri)
A_tri = A_tri * adj_mat
A_tri = A_tri.astype(dtype) test_input = transpose_func(A_func(A_val.copy()), trans)
b = b.astype(dtype)
if not lower: ATOL = 1e-8 if floatX.endswith("64") else 1e-4
A_tri = A_tri.T RTOL = 1e-8 if floatX.endswith("64") else 1e-4
X_np = f(A_tri, b) np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
np.testing.assert_allclose(
transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])
@pytest.mark.parametrize(
    "lower, unit_diag, trans",
    [(True, True, True), (False, False, False)],
    ids=["lower_unit_trans", "defaults"],
)
def test_solve_triangular_grad(lower, unit_diag, trans):
    """Check the gradient of solve_triangular in NUMBA mode with utt.verify_grad."""
    A_val = np.random.normal(size=(5, 5)).astype(floatX)
    b_val = np.random.normal(size=(5, 5)).astype(floatX)

    # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When
    # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raised, but the result will be
    # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices
    # to the space of triangular matrices, and test the gradient of that entire graph.
    def A_func_pt(x):
        # x @ x^H, then its Cholesky factor, gives a triangular matrix for any input x.
        x = x @ x.conj().T
        x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX)

        if unit_diag:
            # Force ones on the diagonal to match unit_diagonal=True.
            n = A_val.shape[0]
            x_tri = x_tri[np.diag_indices(n)].set(1.0)

        return transpose_func(x_tri.astype(floatX), trans)

    solve_op = partial(
        pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
    )

    utt.verify_grad(
        lambda A, b: solve_op(A_func_pt(A), b),
        [A_val.copy(), b_val.copy()],
        mode="NUMBA",
    )
...@@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value): ...@@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
X = pt.linalg.solve_triangular(A, b, check_finite=True) X = pt.linalg.solve_triangular(A, b, check_finite=True)
f = pytensor.function([A, b], X, mode="NUMBA") f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5)) A_val = np.random.normal(size=(5, 5)).astype(floatX)
A_sym = A_val @ A_val.conj().T A_sym = A_val @ A_val.conj().T
A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) A_tri = np.linalg.cholesky(A_sym).astype(floatX)
b = np.full((5, 1), value) b = np.full((5, 1), value).astype(floatX)
with pytest.raises( with pytest.raises(
np.linalg.LinAlgError, np.linalg.LinAlgError,
...@@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans): ...@@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans):
fg = FunctionGraph(outputs=[chol]) fg = FunctionGraph(outputs=[chol])
x = np.array([0.1, 0.2, 0.3]) x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3) + x[None, :] * x[:, None] val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
compare_numba_and_py(fg, [val]) compare_numba_and_py(fg, [val])
def test_numba_Cholesky_raises_on_nan_input(): def test_numba_Cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(config.floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
test_value[0, 0] = np.nan test_value[0, 0] = np.nan
x = pt.tensor(dtype=config.floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x) x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True) g = pt.linalg.cholesky(x)
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
...@@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input(): ...@@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input():
@pytest.mark.parametrize("on_error", ["nan", "raise"]) @pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error): def test_numba_Cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(config.floatX) test_value = rng.random(size=(3, 3)).astype(floatX)
x = pt.tensor(dtype=config.floatX, shape=(3, 3)) x = pt.tensor(dtype=floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error) g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA") f = pytensor.function([x], g, mode="NUMBA")
...@@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error): ...@@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error):
assert np.all(np.isnan(f(test_value))) assert np.all(np.isnan(f(test_value)))
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky_grad(lower):
    """Check the gradient of cholesky in NUMBA mode on a matrix of the form L @ L.T."""
    seeded_rng = np.random.default_rng(utt.fetch_seed())
    factor = rng_draw = seeded_rng.normal(size=(5, 5)).astype(floatX)
    psd_matrix = factor @ rng_draw.T

    utt.verify_grad(
        partial(pt.linalg.cholesky, lower=lower), [psd_matrix], mode="NUMBA"
    )
def test_block_diag(): def test_block_diag():
A = pt.matrix("A") A = pt.matrix("A")
B = pt.matrix("B") B = pt.matrix("B")
...@@ -162,9 +214,242 @@ def test_block_diag(): ...@@ -162,9 +214,242 @@ def test_block_diag():
D = pt.matrix("D") D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D) X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5)) A_val = np.random.normal(size=(5, 5)).astype(floatX)
B_val = np.random.normal(size=(3, 3)) B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)) C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)) D_val = np.random.normal(size=(4, 4)).astype(floatX)
out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X]) out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val]) compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
def test_lamch():
    """Compare the jitted `_xlamch` wrapper with SciPy's lamch for each query kind."""
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.slinalg import _xlamch

    @numba.njit()
    def xlamch(kind):
        return _xlamch(kind)

    # SciPy picks the lamch flavour from the dtype of the prototype array.
    lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))

    for kind in ("E", "S", "P", "B", "R", "M"):
        np.testing.assert_allclose(xlamch(kind), lamch(kind))
@pytest.mark.parametrize(
    "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)]
)
def test_xlange(ord_numba, ord_scipy):
    """Compare the jitted `_xlange` wrapper with scipy.linalg.norm."""
    # xlange is called internally only, we don't dispatch pt.linalg.norm to it
    from scipy import linalg

    from pytensor.link.numba.dispatch.slinalg import _xlange

    @numba.njit()
    def jitted_norm(matrix, norm_kind):
        return _xlange(matrix, norm_kind)

    sample = np.random.normal(size=(5, 5)).astype(floatX)
    expected = linalg.norm(sample, ord_scipy)
    np.testing.assert_allclose(jitted_norm(sample, ord_numba), expected)
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
def test_xgecon(ord_numba, ord_scipy):
    """Compare the jitted `_xgecon` wrapper with SciPy's LAPACK gecon."""
    # gecon is called internally only, there is no pytensor op dispatched to it
    from scipy.linalg import get_lapack_funcs

    from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange

    @numba.njit()
    def numba_gecon(x, norm):
        anorm = _xlange(x, norm)
        cond, info = _xgecon(x, anorm, norm)
        return cond, info

    x = np.random.normal(size=(5, 5)).astype(floatX)

    rcond, info = numba_gecon(x, norm=ord_numba)

    # Test against direct call to the underlying LAPACK functions
    # Solution does **not** agree with 1 / np.linalg.cond(x) !
    sp_lange, sp_gecon = get_lapack_funcs(("lange", "gecon"), (x,))
    reference_norm = sp_lange(ord_numba, x)
    rcond2, _ = sp_gecon(x, reference_norm, norm=ord_numba)

    assert info == 0
    np.testing.assert_allclose(rcond, rcond2)
@pytest.mark.parametrize("overwrite_a", [True, False])
def test_getrf(overwrite_a):
    """Compare the jitted `_getrf` wrapper with scipy.linalg.lu_factor, with and without overwrite_a."""
    from scipy.linalg import lu_factor

    from pytensor.link.numba.dispatch.slinalg import _getrf

    # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor

    @numba.njit()
    def getrf(x, overwrite_a):
        return _getrf(x, overwrite_a=overwrite_a)

    x = np.random.normal(size=(5, 5)).astype(floatX)
    x = np.asfortranarray(
        x
    )  # x needs to be fortran-contiguous going into getrf for the overwrite option to work

    # The reference is computed first, without overwriting, so x is still intact for the jitted call.
    lu, ipiv = lu_factor(x, overwrite_a=False)
    LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)

    assert info == 0
    assert_allclose(LU, lu)

    if overwrite_a:
        # With overwrite_a the input buffer itself must now hold the factorization.
        assert_allclose(x, LU)

    # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
    #  this, though.
    assert_allclose(IPIV - 1, ipiv)
@pytest.mark.parametrize("trans", [0, 1])
@pytest.mark.parametrize("overwrite_a", [True, False])
@pytest.mark.parametrize("overwrite_b", [True, False])
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
    """Compare a jitted `_getrf` + `_getrs` solve with scipy's lu_factor + lu_solve."""
    from scipy.linalg import lu_factor
    from scipy.linalg import lu_solve as sp_lu_solve

    from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs

    # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor

    @numba.njit()
    def lu_solve(a, b, trans, overwrite_a, overwrite_b):
        lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a)
        x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b)
        return x, lu, info

    a = np.random.normal(size=(5, 5)).astype(floatX)
    b = np.random.normal(size=b_shape).astype(floatX)

    # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
    a = np.asfortranarray(a)
    b = np.asfortranarray(b)

    # The reference is computed first, without overwriting, so a and b are still intact for the jitted call.
    lu_and_piv = lu_factor(a, overwrite_a=False)
    x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)

    x, lu, info = lu_solve(
        a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b
    )
    assert info == 0
    if overwrite_a:
        # The input matrix buffer must now hold the factorization.
        assert_allclose(a, lu)
    if overwrite_b:
        # The right-hand side buffer must now hold the solution.
        assert_allclose(b, x)

    assert_allclose(x, x_sp)
@pytest.mark.parametrize(
    "b_shape",
    [(5, 1), (5, 5), (5,)],
    ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.filterwarnings(
    'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def test_solve(b_shape: tuple[int, ...], assume_a: Literal["gen", "sym", "pos"]):
    """Check pt.linalg.solve in NUMBA mode: in-place behaviour, numerical result and gradient."""
    A = pt.matrix("A", dtype=floatX)
    b = pt.tensor("b", shape=b_shape, dtype=floatX)

    A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX))
    b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX))

    def A_func(x):
        # Map an arbitrary matrix to one that satisfies the structure promised by assume_a.
        if assume_a == "pos":
            x = x @ x.T
        elif assume_a == "sym":
            x = (x + x.T) / 2
        return x

    X = pt.linalg.solve(
        A_func(A),
        b,
        assume_a=assume_a,
        b_ndim=len(b_shape),
    )

    # Both inputs are declared mutable so the compiled function is allowed to destroy them.
    f = pytensor.function(
        [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
    )
    op = f.maker.fgraph.outputs[0].owner.op

    compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True)

    # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
    A_val_copy = A_val.copy()
    b_val_copy = b_val.copy()

    X_np = f(A_val, b_val)

    # overwrite_b is preferred when both inputs can be destroyed
    assert op.destroy_map == {0: [1]}

    # Confirm inputs were destroyed by checking against the copies
    assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
    assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])

    ATOL = 1e-8 if floatX.endswith("64") else 1e-4
    RTOL = 1e-8 if floatX.endswith("64") else 1e-4

    # Confirm b_val is used to store the solution
    np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL)
    assert not np.allclose(b_val, b_val_copy)

    # Test that the result is numerically correct. Need to use the unmodified copy
    np.testing.assert_allclose(
        A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
    )

    # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
    utt.verify_grad(
        lambda A, b: pt.linalg.solve(
            A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
        ),
        [A_val_copy, b_val_copy],
        mode="NUMBA",
    )
@pytest.mark.parametrize(
    "b_func, b_size",
    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
    ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}")
def test_cho_solve(b_func, b_size, lower):
    """Check that `cholesky` followed by `cho_solve` in NUMBA mode solves `A X = b`.

    `A` is built as `M @ M.conj().T` from a random matrix `M`, and the result
    is validated by multiplying back: `A @ X` must match `b` within tolerance.
    """
    # Seeded generator (instead of the global np.random state) so that a
    # failing parametrization can be reproduced exactly.
    rng = np.random.default_rng(utt.fetch_seed())

    A = pt.matrix("A", dtype=floatX)
    b = b_func("b", dtype=floatX)

    C = pt.linalg.cholesky(A, lower=lower)
    X = pt.linalg.cho_solve((C, lower), b)
    f = pytensor.function([A, b], X, mode="NUMBA")

    # Numeric inputs get their own names so the symbolic A and b are not shadowed
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    A_val = A_val @ A_val.conj().T
    b_val = rng.normal(size=b_size).astype(floatX)

    X_np = f(A_val, b_val)

    # The same value serves as absolute and relative tolerance
    tol = 1e-8 if floatX.endswith("64") else 1e-4

    np.testing.assert_allclose(A_val @ X_np, b_val, atol=tol, rtol=tol)
...@@ -209,12 +209,12 @@ class TestSolveBase: ...@@ -209,12 +209,12 @@ class TestSolveBase:
) )
class TestSolve(utt.InferShapeTester): def test_solve_raises_on_invalid_A():
def test__init__(self): with pytest.raises(ValueError, match="is not a recognized matrix structure"):
with pytest.raises(ValueError) as excinfo: Solve(assume_a="test", b_ndim=2)
Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value)
class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape): def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester): ...@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester):
warn=False, warn=False,
) )
def test_correctness(self): @pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
def test_solve_correctness(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = pt.tensor("A", shape=(5, 5))
b = matrix() b = pt.tensor("b", shape=b_size)
y = solve(A, b)
gen_solve_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
A_val = np.dot(A_val.transpose(), A_val)
np.testing.assert_allclose( def A_func(x):
scipy.linalg.solve(A_val, b_val, assume_a="gen"), if assume_a == "pos":
gen_solve_func(A_val, b_val), return x @ x.T
) elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_input_val = A_func(A_val)
y = solve_op(A_func(A), b)
solve_func = pytensor.function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy())
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
A_undef = np.array(
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0],
],
dtype=config.floatX,
)
np.testing.assert_allclose( np.testing.assert_allclose(
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
X_np,
atol=ATOL,
rtol=RTOL,
) )
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"m, n, assume_a, lower", "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
],
) )
def test_solve_grad(self, m, n, assume_a, lower): @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.skipif(
config.floatX == "float32", reason="Gradients not numerically stable in float32"
)
def test_solve_gradient(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
# Ensure diagonal elements of `A` are relatively large to avoid eps = 2e-8 if config.floatX == "float64" else None
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if n is None: A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=m).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None def A_func(x):
if config.floatX == "float64": if assume_a == "pos":
eps = 2e-8 return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
# the random perturbations used by verify_grad will result in invalid input matrices, and
# LAPACK will silently do the wrong thing, making the gradients wrong
utt.verify_grad(
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
)
class TestSolveTriangular(utt.InferShapeTester): class TestSolveTriangular(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论