提交 c3d877fe authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Add numba dispatch for hermetian solve

上级 bf0fe7af
...@@ -390,6 +390,51 @@ class _LAPACK: ...@@ -390,6 +390,51 @@ class _LAPACK:
return sysv return sysv
@classmethod
def numba_xhesv(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B with a Hermitian matrix A using the diagonal pivoting method,
factorizing A into LDL^H or UDU^H form, depending on the value of UPLO.
Called by scipy.linalg.solve when assume_a == "her" with complex inputs.
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}hesv"
@numba_basic.numba_njit
def get_hesv_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "hesv")
return ptr
hesv_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # B
nb_i32p, # LDB
float_pointer, # WORK
nb_i32p, # LWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def hesv(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_hesv_pointer,
func_type_ref=hesv_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO)
return hesv
@classmethod @classmethod
def numba_xposv(cls, dtype) -> CPUDispatcher: def numba_xposv(cls, dtype) -> CPUDispatcher:
""" """
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
)
def _solve_hermitian(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for Hermitian matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
assume_a="her",
transposed=transposed,
)
@overload(_solve_hermitian)
def solve_hermitian_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Complex, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Complex, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype
numba_hesv = _LAPACK().numba_xhesv(A.dtype)
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> np.ndarray:
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A
if A.flags.c_contiguous:
# For Hermitian matrices, A^T = conj(A), so transposing
# swaps upper/lower AND conjugates. We can't just flip lower
# like we do for symmetric. We must copy instead.
A_copy = _copy_to_fortran_order(A)
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N) # type: ignore
NRHS = val_to_int_ptr(NRHS)
LDA = val_to_int_ptr(_LDA) # type: ignore
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
LDB = val_to_int_ptr(_N) # type: ignore
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
INFO = val_to_int_ptr(0)
# Workspace query
numba_hesv(
UPLO,
N,
NRHS,
A_copy.ctypes,
LDA,
IPIV.ctypes,
B_copy.ctypes,
LDB,
WORK.ctypes,
LWORK,
INFO,
)
WS_SIZE = np.int32(WORK[0].real)
LWORK = val_to_int_ptr(WS_SIZE)
WORK = np.empty(WS_SIZE, dtype=dtype)
# Actual solve
numba_hesv(
UPLO,
N,
NRHS,
A_copy.ctypes,
LDA,
IPIV.ctypes,
B_copy.ctypes,
LDB,
WORK.ctypes,
LWORK,
INFO,
)
if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d:
B_copy = B_copy[..., 0]
return B_copy
return impl
...@@ -41,6 +41,7 @@ from pytensor.link.numba.dispatch.linalg.decomposition.schur import ( ...@@ -41,6 +41,7 @@ from pytensor.link.numba.dispatch.linalg.decomposition.schur import (
) )
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.hermitian import _solve_hermitian
from pytensor.link.numba.dispatch.linalg.solve.linear_control import ( from pytensor.link.numba.dispatch.linalg.solve.linear_control import (
_trsyl, _trsyl,
) )
...@@ -289,10 +290,10 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -289,10 +290,10 @@ def numba_funcify_Solve(op, node, **kwargs):
print("Solve requires casting second input `b`") # noqa: T201 print("Solve requires casting second input `b`") # noqa: T201
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
assume_a = op.assume_a
lower = op.lower lower = op.lower
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
is_complex = out_dtype.kind == "c"
transposed = False # TODO: Solve doesnt currently allow the transposed argument transposed = False # TODO: Solve doesnt currently allow the transposed argument
if assume_a == "gen": if assume_a == "gen":
...@@ -300,8 +301,8 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -300,8 +301,8 @@ def numba_funcify_Solve(op, node, **kwargs):
elif assume_a == "sym": elif assume_a == "sym":
solve_fn = _solve_symmetric solve_fn = _solve_symmetric
elif assume_a == "her": elif assume_a == "her":
# We already ruled out complex inputs # For real inputs, Hermitian == symmetric
solve_fn = _solve_symmetric solve_fn = _solve_hermitian if is_complex else _solve_symmetric
elif assume_a == "pos": elif assume_a == "pos":
solve_fn = _solve_psd solve_fn = _solve_psd
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
......
...@@ -49,7 +49,9 @@ class TestSolves: ...@@ -49,7 +49,9 @@ class TestSolves:
[(5, 1), (5, 5), (5,)], [(5, 1), (5, 5), (5,)],
ids=["b_col_vec", "b_matrix", "b_vec"], ids=["b_col_vec", "b_matrix", "b_vec"],
) )
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str) @pytest.mark.parametrize(
"assume_a", ["gen", "sym", "her", "pos", "tridiagonal"], ids=str
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) @pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_solve( def test_solve(
self, self,
...@@ -77,6 +79,10 @@ class TestSolves: ...@@ -77,6 +79,10 @@ class TestSolves:
# We have to set the unused triangle to something other than zero # We have to set the unused triangle to something other than zero
# to see lapack destroying it. # to see lapack destroying it.
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
elif assume_a == "her":
x = (x + x.conj().T) / 2
n = x.shape[0]
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
_x = x _x = x
x = np.zeros_like(x) x = np.zeros_like(x)
...@@ -152,7 +158,7 @@ class TestSolves: ...@@ -152,7 +158,7 @@ class TestSolves:
# We can destroy C-contiguous A arrays by inverting `transpose/lower` at runtime # We can destroy C-contiguous A arrays by inverting `transpose/lower` at runtime
# Complex posdef/hermitian can't use this trick (A^T = conj(A) != A for Hermitian) # Complex posdef/hermitian can't use this trick (A^T = conj(A) != A for Hermitian)
can_destroy_c_contig_A = overwrite_a and not ( can_destroy_c_contig_A = overwrite_a and not (
is_complex and assume_a in ("pos",) is_complex and assume_a in ("pos", "her")
) )
assert np.allclose(A_val_c_contig, A_val) == (not can_destroy_c_contig_A) assert np.allclose(A_val_c_contig, A_val) == (not can_destroy_c_contig_A)
# b vectors are always f_contiguous if also c_contiguous # b vectors are always f_contiguous if also c_contiguous
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论