提交 9f911e35 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add Numba dispatch for QZ

上级 ad8dca48
...@@ -946,3 +946,388 @@ class _LAPACK: ...@@ -946,3 +946,388 @@ class _LAPACK:
fn(TRANA, TRANB, ISGN, M, N, A, LDA, B, LDB, C, LDC, SCALE, INFO) fn(TRANA, TRANB, ISGN, M, N, A, LDA, B, LDB, C, LDC, SCALE, INFO)
return trsyl return trsyl
@classmethod
def numba_xgges(cls, dtype):
    """
    Compute generalized eigenvalues and, optionally, the left and/or right generalized Schur vectors of a pair
    of real nonsymmetric matrices (A,B).

    Called by scipy.linalg.qz and scipy.linalg.ordqz.

    Returns a njit-able wrapper around the LAPACK ``?gges`` routine for the
    given numba ``dtype``. The complex routines (cgges/zgges) take a single
    complex ALPHA array plus an extra real RWORK buffer, while the real
    routines (sgges/dgges) split ALPHA into ALPHAR/ALPHAI and take no RWORK,
    so two distinct signatures are built below.
    """
    kind = get_blas_kind(dtype)
    float_pointer = _get_nb_float_from_dtype(kind)
    unique_func_name = f"scipy.lapack.{kind}gges"

    @numba_basic.numba_njit
    def get_gges_pointer():
        # Resolving the LAPACK symbol address needs Python-level code, hence objmode.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "gges")
        return ptr

    if isinstance(dtype, Complex):
        # Real-valued workspace matching the precision of the complex dtype.
        real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p
        gges_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # JOBVSL
                nb_i32p,  # JOBVSR
                nb_i32p,  # SORT
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                float_pointer,  # B
                nb_i32p,  # LDB
                nb_i32p,  # SDIM
                float_pointer,  # ALPHA
                float_pointer,  # BETA
                float_pointer,  # VSL
                nb_i32p,  # LDVSL
                float_pointer,  # VSR
                nb_i32p,  # LDVSR
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                real_pointer,  # RWORK
                nb_i32p,  # BWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A,
            LDA,
            B,
            LDB,
            SDIM,
            ALPHA,
            BETA,
            VSL,
            LDVSL,
            VSR,
            LDVSR,
            WORK,
            LWORK,
            RWORK,
            BWORK,
            INFO,
        ):
            # Fetch (and cache) the raw function pointer, then forward all
            # arguments untouched to the LAPACK routine.
            fn = _call_cached_ptr(
                get_ptr_func=get_gges_pointer,
                func_type_ref=gges_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                JOBVSL,
                JOBVSR,
                SORT,
                SELECT,
                N,
                A,
                LDA,
                B,
                LDB,
                SDIM,
                ALPHA,
                BETA,
                VSL,
                LDVSL,
                VSR,
                LDVSR,
                WORK,
                LWORK,
                RWORK,
                BWORK,
                INFO,
            )

    else:  # Real case
        gges_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # JOBVSL
                nb_i32p,  # JOBVSR
                nb_i32p,  # SORT
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                float_pointer,  # B
                nb_i32p,  # LDB
                nb_i32p,  # SDIM
                float_pointer,  # ALPHAR
                float_pointer,  # ALPHAI
                float_pointer,  # BETA
                float_pointer,  # VSL
                nb_i32p,  # LDVSL
                float_pointer,  # VSR
                nb_i32p,  # LDVSR
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                nb_i32p,  # BWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A,
            LDA,
            B,
            LDB,
            SDIM,
            ALPHAR,
            ALPHAI,
            BETA,
            VSL,
            LDVSL,
            VSR,
            LDVSR,
            WORK,
            LWORK,
            BWORK,
            INFO,
        ):
            fn = _call_cached_ptr(
                get_ptr_func=get_gges_pointer,
                func_type_ref=gges_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                JOBVSL,
                JOBVSR,
                SORT,
                SELECT,
                N,
                A,
                LDA,
                B,
                LDB,
                SDIM,
                ALPHAR,
                ALPHAI,
                BETA,
                VSL,
                LDVSL,
                VSR,
                LDVSR,
                WORK,
                LWORK,
                BWORK,
                INFO,
            )

    return gges
@classmethod
def numba_tgsen(cls, dtype):
    """
    Reorders the generalized Schur decomposition of a matrix pair (A, B) by their eigenvalues.

    Output is sorted so that a selected cluster of eigenvalues appears in the leading diagonal blocks of the pair
    (A,B). The leading columns of Q and Z form unitary bases of the corresponding left and right eigenspaces
    (deflating subspaces). (A, B) must be in generalized Schur canonical form, that is, A and B are both upper
    triangular.

    Used by scipy.linalg.ordqz.

    Returns a njit-able wrapper around the LAPACK ``?tgsen`` routine. As with
    ``numba_xgges``, the complex routines use a single complex alpha array
    while the real routines split it into ALPHAR/ALPHAI, hence two signatures.
    """
    kind = get_blas_kind(dtype)
    float_pointer = _get_nb_float_from_dtype(kind)
    unique_func_name = f"scipy.lapack.{kind}tgsen"

    @numba_basic.numba_njit
    def get_tgsen_pointer():
        # Resolving the LAPACK symbol address needs Python-level code, hence objmode.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "tgsen")
        return ptr

    if isinstance(dtype, Complex):
        # PL/PR/DIF are always real, even for complex matrices.
        real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p
        tgsen_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # IJOB
                nb_i32p,  # WANTQ
                nb_i32p,  # WANTZ
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                float_pointer,  # B
                nb_i32p,  # LDB
                float_pointer,  # alpha
                float_pointer,  # beta
                float_pointer,  # Q
                nb_i32p,  # LDQ
                float_pointer,  # Z
                nb_i32p,  # LDZ
                nb_i32p,  # M
                real_pointer,  # PL
                real_pointer,  # PR
                real_pointer,  # DIF
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                nb_i32p,  # IWORK
                nb_i32p,  # LIWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            SELECT,
            N,
            A,
            LDA,
            B,
            LDB,
            alpha,
            beta,
            Q,
            LDQ,
            Z,
            LDZ,
            M,
            PL,
            PR,
            DIF,
            WORK,
            LWORK,
            IWORK,
            LIWORK,
            INFO,
        ):
            # Fetch (and cache) the raw function pointer, then forward all
            # arguments untouched to the LAPACK routine.
            fn = _call_cached_ptr(
                get_ptr_func=get_tgsen_pointer,
                func_type_ref=tgsen_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                IJOB,
                WANTQ,
                WANTZ,
                SELECT,
                N,
                A,
                LDA,
                B,
                LDB,
                alpha,
                beta,
                Q,
                LDQ,
                Z,
                LDZ,
                M,
                PL,
                PR,
                DIF,
                WORK,
                LWORK,
                IWORK,
                LIWORK,
                INFO,
            )

    else:  # Real case
        tgsen_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # IJOB
                nb_i32p,  # WANTQ
                nb_i32p,  # WANTZ
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                float_pointer,  # B
                nb_i32p,  # LDB
                float_pointer,  # ALPHAR
                float_pointer,  # ALPHAI
                float_pointer,  # BETA
                float_pointer,  # Q
                nb_i32p,  # LDQ
                float_pointer,  # Z
                nb_i32p,  # LDZ
                nb_i32p,  # M
                float_pointer,  # PL
                float_pointer,  # PR
                float_pointer,  # DIF
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                nb_i32p,  # IWORK
                nb_i32p,  # LIWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            SELECT,
            N,
            A,
            LDA,
            B,
            LDB,
            ALPHAR,
            ALPHAI,
            BETA,
            Q,
            LDQ,
            Z,
            LDZ,
            M,
            PL,
            PR,
            DIF,
            WORK,
            LWORK,
            IWORK,
            LIWORK,
            INFO,
        ):
            fn = _call_cached_ptr(
                get_ptr_func=get_tgsen_pointer,
                func_type_ref=tgsen_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                IJOB,
                WANTQ,
                WANTZ,
                SELECT,
                N,
                A,
                LDA,
                B,
                LDB,
                ALPHAR,
                ALPHAI,
                BETA,
                Q,
                LDQ,
                Z,
                LDZ,
                M,
                PL,
                PR,
                DIF,
                WORK,
                LWORK,
                IWORK,
                LIWORK,
                INFO,
            )

    return tgsen
import numpy as np
import scipy.linalg as scipy_linalg
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
@numba_basic.numba_njit
def _lhp(alpha, beta):
    """Select generalized eigenvalues in the left half-plane: Re(alpha/beta) < 0.

    Entries with beta == 0 (infinite eigenvalues) are never selected. Returns
    an int32 mask array suitable for LAPACK's LOGICAL SELECT input.
    """
    selected = np.empty(alpha.shape, dtype=np.int32)
    finite = beta != 0
    # Only divide where beta is nonzero to avoid division by zero.
    selected[finite] = np.real(alpha[finite] / beta[finite]) < 0.0
    selected[~finite] = False
    return selected
@numba_basic.numba_njit
def _rhp(alpha, beta):
    """Select generalized eigenvalues in the right half-plane: Re(alpha/beta) > 0.

    Entries with beta == 0 (infinite eigenvalues) are never selected. Returns
    an int32 mask array suitable for LAPACK's LOGICAL SELECT input.
    """
    selected = np.empty(alpha.shape, dtype=np.int32)
    finite = beta != 0
    # Only divide where beta is nonzero to avoid division by zero.
    selected[finite] = np.real(alpha[finite] / beta[finite]) > 0.0
    selected[~finite] = False
    return selected
@numba_basic.numba_njit
def _iuc(alpha, beta):
    """Select generalized eigenvalues inside the unit circle: |alpha/beta| < 1.

    Entries with beta == 0 (infinite eigenvalues) are never selected. Returns
    an int32 mask array suitable for LAPACK's LOGICAL SELECT input.
    """
    selected = np.empty(alpha.shape, dtype=np.int32)
    finite = beta != 0
    # Only divide where beta is nonzero to avoid division by zero.
    selected[finite] = np.abs(alpha[finite] / beta[finite]) < 1.0
    selected[~finite] = False
    return selected
@numba_basic.numba_njit
def _ouc(alpha, beta):
    """Select generalized eigenvalues outside the unit circle: |alpha/beta| > 1.

    beta == 0 with alpha != 0 is an infinite eigenvalue and counts as outside;
    alpha == beta == 0 is indeterminate and is not selected. Returns an int32
    mask array suitable for LAPACK's LOGICAL SELECT input.
    """
    selected = np.empty(alpha.shape, dtype=np.int32)
    zero_a = alpha == 0
    zero_b = beta == 0
    finite = ~zero_b
    # The three masks below partition the array, so assignment order is irrelevant.
    selected[finite] = np.abs(alpha[finite] / beta[finite]) > 1.0
    selected[zero_b & zero_a] = False
    selected[zero_b & ~zero_a] = True
    return selected
def _qz_real_nosort_noeig(A, B, overwrite_a=False, overwrite_b=False):
S, T, Q, Z = scipy_linalg.qz(
A,
B,
output="real",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, Q, Z
def _qz_real_nosort_eig(A, B, overwrite_a=False, overwrite_b=False):
    """Real QZ decomposition of (A, B), also returning generalized eigenvalues.

    Returns (S, T, alpha, beta, Q, Z), where the generalized eigenvalues are
    alpha / beta with alpha complex and beta real. Python-level fallback
    target for the numba overload.
    """
    S, T, Q, Z = scipy_linalg.qz(
        A,
        B,
        output="real",
        overwrite_a=overwrite_a,
        overwrite_b=overwrite_b,
        check_finite=False,
    )
    # There is no option to return eigenvalues directly from scipy.linalg.qz, so we have to compute them manually.
    # Unlike the complex Schur form, the real Schur form can have 2x2 blocks on the main diagonal for complex conjugate
    # pairs, so we can't just read off the eigenvalues and the diagonal elements of S and T.
    n = S.shape[0]
    # alpha must be complex to hold conjugate pairs; match its precision to S.
    alpha = np.empty(
        n,
        dtype=np.complex128
        if _get_underlying_float(S.dtype) == np.float64
        else np.complex64,
    )
    beta = np.empty(n, dtype=S.dtype)
    i = 0
    while i < n:
        # A nonzero subdiagonal entry marks the start of a 2x2 block.
        if i == n - 1 or S[i + 1, i] == 0:
            # 1x1 block - real eigenvalue
            alpha[i] = S[i, i]
            beta[i] = T[i, i]
            i += 1
        else:
            # 2x2 block - complex conjugate pair
            a11, a12, a21, a22 = S[i, i], S[i, i + 1], S[i + 1, i], S[i + 1, i + 1]
            b11, b22 = T[i, i], T[i + 1, i + 1]
            # For standardized 2x2 blocks, eigenvalues are roots of det(A - lambda*B)
            # NOTE(review): divides by b11 * b22 — assumes the standardized block
            # has nonzero diagonal in T; confirm against LAPACK dhgeqz output.
            tr = (a11 * b22 + a22 * b11) / (b11 * b22)
            det = (a11 * a22 - a12 * a21) / (b11 * b22)
            disc = tr * tr / 4 - det
            if disc < 0:
                # Complex pair: lambda = tr/2 +/- i*sqrt(-disc).
                sqrt_disc = np.sqrt(-disc)
                alpha[i] = tr / 2 + 1j * sqrt_disc
                alpha[i + 1] = tr / 2 - 1j * sqrt_disc
            else:
                sqrt_disc = np.sqrt(disc)
                alpha[i] = tr / 2 + sqrt_disc
                alpha[i + 1] = tr / 2 - sqrt_disc
            # The eigenvalues themselves were stored in alpha, so beta is
            # normalized to 1 for these entries.
            beta[i] = 1.0
            beta[i + 1] = 1.0
            i += 2
    return S, T, alpha, beta, Q, Z
def _qz_real_sort_noeig(A, B, sort, overwrite_a=False, overwrite_b=False):
S, T, _, _, Q, Z = scipy_linalg.ordqz(
A,
B,
sort=sort,
output="real",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, Q, Z
def _qz_real_sort_eig(A, B, sort, overwrite_a=False, overwrite_b=False):
S, T, alpha, beta, Q, Z = scipy_linalg.ordqz(
A,
B,
sort=sort,
output="real",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, alpha, beta, Q, Z
def _qz_complex_nosort_noeig(A, B, overwrite_a=False, overwrite_b=False):
S, T, Q, Z = scipy_linalg.qz(
A,
B,
output="complex",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, Q, Z
def _qz_complex_nosort_eig(A, B, overwrite_a=False, overwrite_b=False):
S, T, Q, Z = scipy_linalg.qz(
A,
B,
output="complex",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
# For complex Schur form, eigenvalues are simply the diagonal elements
alpha = np.diag(S)
beta = np.diag(T)
return S, T, alpha, beta, Q, Z
def _qz_complex_sort_noeig(A, B, sort, overwrite_a=False, overwrite_b=False):
S, T, _, _, Q, Z = scipy_linalg.ordqz(
A,
B,
sort=sort,
output="complex",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, Q, Z
def _qz_complex_sort_eig(A, B, sort, overwrite_a=False, overwrite_b=False):
S, T, alpha, beta, Q, Z = scipy_linalg.ordqz(
A,
B,
sort=sort,
output="complex",
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
)
return S, T, alpha, beta, Q, Z
@overload(_qz_real_nosort_noeig)
def qz_real_nosort_noeig_impl(A, B, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_real_nosort_noeig``: real QZ via LAPACK ?gges,
    no eigenvalue sorting, returning only (S, T, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Float,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Float,), func_name="qz")
    dtype = A.dtype
    numba_gges = _LAPACK().numba_xgges(dtype)

    def impl(A, B, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # LAPACK works in-place on Fortran-ordered buffers; reuse the caller's
        # array only when overwriting is allowed and it is already F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # LWORK = -1 requests a workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))  # compute left Schur vectors
        JOBVSR = val_to_int_ptr(ord("V"))  # compute right Schur vectors
        SORT = val_to_int_ptr(ord("N"))  # no eigenvalue ordering
        SELECT = val_to_int_ptr(0)  # unused when SORT == 'N'
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHAR = np.empty(_N, dtype=dtype)
        ALPHAI = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        BWORK = val_to_int_ptr(1)  # unused when SORT == 'N'
        INFO = val_to_int_ptr(0)

        # Workspace query
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # Optimal workspace size is returned in WORK[0].
        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # On failure, poison the outputs with NaN instead of raising.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        # VSL/VSR were written in Fortran order; transpose back to C order.
        return A_copy, B_copy, VSL.T, VSR.T

    return impl
@overload(_qz_real_nosort_eig)
def qz_real_nosort_eig_impl(A, B, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_real_nosort_eig``: real QZ via LAPACK ?gges,
    no sorting, returning (S, T, alpha, beta, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Float,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Float,), func_name="qz")
    dtype = A.dtype
    numba_gges = _LAPACK().numba_xgges(dtype)

    def impl(A, B, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHAR = np.empty(_N, dtype=dtype)
        ALPHAI = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # On failure, poison outputs with NaN instead of raising.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        # Recombine the split real/imaginary parts into complex alpha.
        alpha = ALPHAR + 1j * ALPHAI
        return A_copy, B_copy, alpha, BETA, VSL.T, VSR.T

    return impl
@overload(_qz_real_sort_noeig)
def qz_real_sort_noeig_impl(A, B, sort, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_real_sort_noeig``: real QZ via LAPACK ?gges,
    then eigenvalue reordering via ?tgsen, returning (S, T, Q, Z).

    Sorting is done in a second pass with tgsen (rather than gges's own SORT
    machinery, which needs a callback) using the selection masks
    _lhp/_rhp/_iuc/_ouc.
    """
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Float,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Float,), func_name="qz")
    dtype = A.dtype
    numba_gges = _LAPACK().numba_xgges(dtype)
    numba_tgsen = _LAPACK().numba_tgsen(dtype)

    def impl(A, B, sort, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))  # sorting is applied afterwards via tgsen
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHAR = np.empty(_N, dtype=dtype)
        ALPHAI = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query for gges
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual gges call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # Bail out early with NaN-poisoned outputs if the decomposition failed.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan
            return A_copy, B_copy, VSL.T, VSR.T

        # Apply sorting via tgsen
        alpha = ALPHAR + 1j * ALPHAI
        if sort == "lhp":
            select = _lhp(alpha, BETA)
        elif sort == "rhp":
            select = _rhp(alpha, BETA)
        elif sort == "iuc":
            select = _iuc(alpha, BETA)
        else:  # ouc
            select = _ouc(alpha, BETA)

        IJOB = val_to_int_ptr(0)  # reorder only; no condition-number estimates
        WANTQ = val_to_int_ptr(1)
        WANTZ = val_to_int_ptr(1)
        LDQ = val_to_int_ptr(_N)
        LDZ = val_to_int_ptr(_N)
        M = val_to_int_ptr(0)
        PL = np.empty(1, dtype=dtype)
        PR = np.empty(1, dtype=dtype)
        DIF = np.empty(2, dtype=dtype)
        # Minimum workspace for real tgsen with IJOB=0 (4*N + 16 per LAPACK docs).
        TGSEN_LWORK = val_to_int_ptr(4 * _N + 16)
        TGSEN_WORK = np.empty(4 * _N + 16, dtype=dtype)
        LIWORK = val_to_int_ptr(1)
        IWORK = np.empty(1, dtype=np.int32)
        INFO = val_to_int_ptr(0)

        # gges' VSL/VSR play the roles of tgsen's Q/Z.
        numba_tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            select.ctypes,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDQ,
            VSR.ctypes,
            LDZ,
            M,
            PL.ctypes,
            PR.ctypes,
            DIF.ctypes,
            TGSEN_WORK.ctypes,
            TGSEN_LWORK,
            IWORK.ctypes,
            LIWORK,
            INFO,
        )

        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        # Transpose Fortran-ordered vector matrices back to C order.
        return A_copy, B_copy, VSL.T, VSR.T

    return impl
@overload(_qz_real_sort_eig)
def qz_real_sort_eig_impl(A, B, sort, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_real_sort_eig``: real QZ via LAPACK ?gges plus
    reordering via ?tgsen, returning (S, T, alpha, beta, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Float,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Float,), func_name="qz")
    dtype = A.dtype
    numba_gges = _LAPACK().numba_xgges(dtype)
    numba_tgsen = _LAPACK().numba_tgsen(dtype)

    def impl(A, B, sort, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))  # sorting is applied afterwards via tgsen
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHAR = np.empty(_N, dtype=dtype)
        ALPHAI = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query for gges
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual gges call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # Bail out early with NaN-poisoned outputs if the decomposition failed.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan
            alpha = ALPHAR + 1j * ALPHAI
            return A_copy, B_copy, alpha, BETA, VSL.T, VSR.T

        # Apply sorting via tgsen
        alpha = ALPHAR + 1j * ALPHAI
        if sort == "lhp":
            select = _lhp(alpha, BETA)
        elif sort == "rhp":
            select = _rhp(alpha, BETA)
        elif sort == "iuc":
            select = _iuc(alpha, BETA)
        else:  # ouc
            select = _ouc(alpha, BETA)

        IJOB = val_to_int_ptr(0)  # reorder only; no condition-number estimates
        WANTQ = val_to_int_ptr(1)
        WANTZ = val_to_int_ptr(1)
        LDQ = val_to_int_ptr(_N)
        LDZ = val_to_int_ptr(_N)
        M = val_to_int_ptr(0)
        PL = np.empty(1, dtype=dtype)
        PR = np.empty(1, dtype=dtype)
        DIF = np.empty(2, dtype=dtype)
        # Minimum workspace for real tgsen with IJOB=0 (4*N + 16 per LAPACK docs).
        TGSEN_LWORK = val_to_int_ptr(4 * _N + 16)
        TGSEN_WORK = np.empty(4 * _N + 16, dtype=dtype)
        LIWORK = val_to_int_ptr(1)
        IWORK = np.empty(1, dtype=np.int32)
        INFO = val_to_int_ptr(0)

        # gges' VSL/VSR play the roles of tgsen's Q/Z; tgsen also rewrites
        # ALPHAR/ALPHAI/BETA in the new (sorted) order.
        numba_tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            select.ctypes,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            ALPHAR.ctypes,
            ALPHAI.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDQ,
            VSR.ctypes,
            LDZ,
            M,
            PL.ctypes,
            PR.ctypes,
            DIF.ctypes,
            TGSEN_WORK.ctypes,
            TGSEN_LWORK,
            IWORK.ctypes,
            LIWORK,
            INFO,
        )

        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        # Recompute alpha after tgsen
        alpha = ALPHAR + 1j * ALPHAI
        return A_copy, B_copy, alpha, BETA, VSL.T, VSR.T

    return impl
@overload(_qz_complex_nosort_noeig)
def qz_complex_nosort_noeig_impl(A, B, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_complex_nosort_noeig``: complex QZ via LAPACK
    ?gges (single complex ALPHA, extra real RWORK), returning (S, T, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Complex,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Complex,), func_name="qz")
    dtype = A.dtype
    real_dtype = _get_underlying_float(dtype)
    numba_gges = _LAPACK().numba_xgges(dtype)

    def impl(A, B, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHA = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        # Real scratch required by the complex routines (size 8*N per LAPACK docs).
        RWORK = np.empty(8 * _N, dtype=real_dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        # On failure, poison outputs with NaN instead of raising.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        # Transpose Fortran-ordered vector matrices back to C order.
        return A_copy, B_copy, VSL.T, VSR.T

    return impl
@overload(_qz_complex_nosort_eig)
def qz_complex_nosort_eig_impl(A, B, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_complex_nosort_eig``: complex QZ via LAPACK
    ?gges, returning (S, T, alpha, beta, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Complex,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Complex,), func_name="qz")
    dtype = A.dtype
    real_dtype = _get_underlying_float(dtype)
    numba_gges = _LAPACK().numba_xgges(dtype)

    def impl(A, B, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHA = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        # Real scratch required by the complex routines.
        RWORK = np.empty(8 * _N, dtype=real_dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        # On failure, poison outputs with NaN instead of raising.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        return A_copy, B_copy, ALPHA, BETA, VSL.T, VSR.T

    return impl
@overload(_qz_complex_sort_noeig)
def qz_complex_sort_noeig_impl(A, B, sort, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_complex_sort_noeig``: complex QZ via LAPACK
    ?gges plus eigenvalue reordering via ?tgsen, returning (S, T, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Complex,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Complex,), func_name="qz")
    dtype = A.dtype
    real_dtype = _get_underlying_float(dtype)
    numba_gges = _LAPACK().numba_xgges(dtype)
    numba_tgsen = _LAPACK().numba_tgsen(dtype)

    def impl(A, B, sort, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))  # sorting is applied afterwards via tgsen
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHA = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        # Real scratch required by the complex routines.
        RWORK = np.empty(8 * _N, dtype=real_dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query for gges
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual gges call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        # Bail out early with NaN-poisoned outputs if the decomposition failed.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan
            return A_copy, B_copy, VSL.T, VSR.T

        # Apply sorting via tgsen
        if sort == "lhp":
            select = _lhp(ALPHA, BETA)
        elif sort == "rhp":
            select = _rhp(ALPHA, BETA)
        elif sort == "iuc":
            select = _iuc(ALPHA, BETA)
        else:  # ouc
            select = _ouc(ALPHA, BETA)

        IJOB = val_to_int_ptr(0)  # reorder only; no condition-number estimates
        WANTQ = val_to_int_ptr(1)
        WANTZ = val_to_int_ptr(1)
        LDQ = val_to_int_ptr(_N)
        LDZ = val_to_int_ptr(_N)
        M = val_to_int_ptr(0)
        # PL/PR/DIF are real-valued even for complex matrices.
        PL = np.empty(1, dtype=real_dtype)
        PR = np.empty(1, dtype=real_dtype)
        DIF = np.empty(2, dtype=real_dtype)
        # Minimum workspace for complex tgsen with IJOB=0 is 1 per LAPACK docs.
        TGSEN_LWORK = val_to_int_ptr(1)
        TGSEN_WORK = np.empty(1, dtype=dtype)
        LIWORK = val_to_int_ptr(1)
        IWORK = np.empty(1, dtype=np.int32)
        INFO = val_to_int_ptr(0)

        # gges' VSL/VSR play the roles of tgsen's Q/Z.
        numba_tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            select.ctypes,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDQ,
            VSR.ctypes,
            LDZ,
            M,
            PL.ctypes,
            PR.ctypes,
            DIF.ctypes,
            TGSEN_WORK.ctypes,
            TGSEN_LWORK,
            IWORK.ctypes,
            LIWORK,
            INFO,
        )

        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        return A_copy, B_copy, VSL.T, VSR.T

    return impl
@overload(_qz_complex_sort_eig)
def qz_complex_sort_eig_impl(A, B, sort, overwrite_a, overwrite_b):
    """Numba overload of ``_qz_complex_sort_eig``: complex QZ via LAPACK ?gges
    plus reordering via ?tgsen, returning (S, T, alpha, beta, Q, Z)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Complex,), func_name="qz")
    _check_linalg_matrix(B, ndim=2, dtype=(Complex,), func_name="qz")
    dtype = A.dtype
    real_dtype = _get_underlying_float(dtype)
    numba_gges = _LAPACK().numba_xgges(dtype)
    numba_tgsen = _LAPACK().numba_tgsen(dtype)

    def impl(A, B, sort, overwrite_a, overwrite_b):
        _N = np.int32(A.shape[-1])
        # Reuse the caller's buffers only when allowed and F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order(B)

        WORK = np.empty(1, dtype=dtype)
        LWORK = val_to_int_ptr(-1)  # workspace-size query

        JOBVSL = val_to_int_ptr(ord("V"))
        JOBVSR = val_to_int_ptr(ord("V"))
        SORT = val_to_int_ptr(ord("N"))  # sorting is applied afterwards via tgsen
        SELECT = val_to_int_ptr(0)
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(0)
        ALPHA = np.empty(_N, dtype=dtype)
        BETA = np.empty(_N, dtype=dtype)
        LDVSL = val_to_int_ptr(_N)
        VSL = np.empty((_N, _N), dtype=dtype)
        LDVSR = val_to_int_ptr(_N)
        VSR = np.empty((_N, _N), dtype=dtype)
        # Real scratch required by the complex routines.
        RWORK = np.empty(8 * _N, dtype=real_dtype)
        BWORK = val_to_int_ptr(1)
        INFO = val_to_int_ptr(0)

        # Workspace query for gges
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        WS_SIZE = np.int32(WORK[0].real)
        LWORK = val_to_int_ptr(WS_SIZE)
        WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual call
        numba_gges(
            JOBVSL,
            JOBVSR,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            SDIM,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDVSL,
            VSR.ctypes,
            LDVSR,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        # Bail out early with NaN-poisoned outputs if the decomposition failed.
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan
            return A_copy, B_copy, ALPHA, BETA, VSL.T, VSR.T

        # Apply sorting via tgsen
        if sort == "lhp":
            select = _lhp(ALPHA, BETA)
        elif sort == "rhp":
            select = _rhp(ALPHA, BETA)
        elif sort == "iuc":
            select = _iuc(ALPHA, BETA)
        else:  # ouc
            select = _ouc(ALPHA, BETA)

        IJOB = val_to_int_ptr(0)  # reorder only; no condition-number estimates
        WANTQ = val_to_int_ptr(1)
        WANTZ = val_to_int_ptr(1)
        LDQ = val_to_int_ptr(_N)
        LDZ = val_to_int_ptr(_N)
        M = val_to_int_ptr(0)
        # PL/PR/DIF are real-valued even for complex matrices.
        PL = np.empty(1, dtype=real_dtype)
        PR = np.empty(1, dtype=real_dtype)
        DIF = np.empty(2, dtype=real_dtype)
        # Minimum workspace for complex tgsen with IJOB=0 is 1 per LAPACK docs.
        TGSEN_LWORK = val_to_int_ptr(1)
        TGSEN_WORK = np.empty(1, dtype=dtype)
        LIWORK = val_to_int_ptr(1)
        IWORK = np.empty(1, dtype=np.int32)
        INFO = val_to_int_ptr(0)

        # gges' VSL/VSR play the roles of tgsen's Q/Z; tgsen also rewrites
        # ALPHA/BETA in the new (sorted) order.
        numba_tgsen(
            IJOB,
            WANTQ,
            WANTZ,
            select.ctypes,
            N,
            A_copy.ctypes,
            LDA,
            B_copy.ctypes,
            LDB,
            ALPHA.ctypes,
            BETA.ctypes,
            VSL.ctypes,
            LDQ,
            VSR.ctypes,
            LDZ,
            M,
            PL.ctypes,
            PR.ctypes,
            DIF.ctypes,
            TGSEN_WORK.ctypes,
            TGSEN_LWORK,
            IWORK.ctypes,
            LIWORK,
            INFO,
        )

        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan
            B_copy[:] = np.nan
            VSL[:] = np.nan
            VSR[:] = np.nan

        return A_copy, B_copy, ALPHA, BETA, VSL.T, VSR.T

    return impl
...@@ -25,6 +25,16 @@ from pytensor.link.numba.dispatch.linalg.decomposition.qr import ( ...@@ -25,6 +25,16 @@ from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
_qr_raw_no_pivot, _qr_raw_no_pivot,
_qr_raw_pivot, _qr_raw_pivot,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.qz import (
_qz_complex_nosort_eig,
_qz_complex_nosort_noeig,
_qz_complex_sort_eig,
_qz_complex_sort_noeig,
_qz_real_nosort_eig,
_qz_real_nosort_noeig,
_qz_real_sort_eig,
_qz_real_sort_noeig,
)
from pytensor.link.numba.dispatch.linalg.decomposition.schur import ( from pytensor.link.numba.dispatch.linalg.decomposition.schur import (
schur_complex, schur_complex,
schur_real, schur_real,
...@@ -46,6 +56,7 @@ from pytensor.tensor._linalg.solve.linear_control import TRSYL ...@@ -46,6 +56,7 @@ from pytensor.tensor._linalg.solve.linear_control import TRSYL
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR, QR,
QZ,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -535,6 +546,94 @@ def numba_funcify_Schur(op, node, **kwargs): ...@@ -535,6 +546,94 @@ def numba_funcify_Schur(op, node, **kwargs):
return schur, cache_version return schur, cache_version
@register_funcify_default_op_cache_key(QZ)
def numba_funcify_QZ(op, node, **kwargs):
    """Funcify the QZ Op: pick one of the eight specialized helpers and wrap
    it with the input casting the node requires.

    Dispatch is on three booleans: complex vs real path, sorted vs unsorted,
    and whether eigenvalues are returned. Inputs are cast to the output dtype
    when they are integer (LAPACK needs floats) or when a real input feeds a
    complex output.
    """
    complex_output = op.complex_output
    sort = op.sort
    return_eigenvalues = op.return_eigenvalues
    overwrite_a = op.overwrite_a
    overwrite_b = op.overwrite_b

    in_dtype_a = node.inputs[0].type.numpy_dtype
    in_dtype_b = node.inputs[1].type.numpy_dtype
    out_dtype = node.outputs[0].type.numpy_dtype

    integer_input_a = in_dtype_a.kind in "ibu"
    integer_input_b = in_dtype_b.kind in "ibu"
    complex_input = in_dtype_a.kind == "c" or in_dtype_b.kind == "c"
    # numpy's kind is a single character; all real floats have kind "f".
    needs_complex_cast = (
        in_dtype_a.kind == "f" or in_dtype_b.kind == "f"
    ) and complex_output

    # Disable overwrite for dtype conversion (real->complex upcast): the cast
    # allocates a new array anyway, so in-place operation is impossible.
    if needs_complex_cast:
        overwrite_a = False
        overwrite_b = False
        if config.compiler_verbose:
            print(  # noqa: T201
                "QZ: disabling overwrite_a/b due to dtype conversion (casting prevents in-place operation)"
            )

    if (integer_input_a or integer_input_b) and config.compiler_verbose:
        print("QZ requires casting discrete input to float")  # noqa: T201

    use_complex = complex_input or complex_output
    use_sort = sort is not None

    # Select the specialized helper for this (complex, sort, eig) combination.
    if use_complex:
        if return_eigenvalues:
            qz_fn = _qz_complex_sort_eig if use_sort else _qz_complex_nosort_eig
        else:
            qz_fn = _qz_complex_sort_noeig if use_sort else _qz_complex_nosort_noeig
    else:
        if return_eigenvalues:
            qz_fn = _qz_real_sort_eig if use_sort else _qz_real_nosort_eig
        else:
            qz_fn = _qz_real_sort_noeig if use_sort else _qz_real_nosort_noeig

    # Hoist the cast decision: integer inputs and real->complex upcasts both
    # resolve to the same astype(out_dtype), so a single flag per input suffices
    # (the original if/elif branches were identical).
    cast_a = integer_input_a or needs_complex_cast
    cast_b = integer_input_b or needs_complex_cast

    if use_sort:

        @numba_basic.numba_njit
        def qz(a, b):
            if cast_a:
                a = a.astype(out_dtype)
            if cast_b:
                b = b.astype(out_dtype)
            return qz_fn(a, b, sort, overwrite_a, overwrite_b)

    else:

        @numba_basic.numba_njit
        def qz(a, b):
            if cast_a:
                a = a.astype(out_dtype)
            if cast_b:
                b = b.astype(out_dtype)
            return qz_fn(a, b, overwrite_a, overwrite_b)

    cache_version = 1
    return qz, cache_version
@register_funcify_default_op_cache_key(TRSYL) @register_funcify_default_op_cache_key(TRSYL)
def numba_funcify_TRSYL(op, node, **kwargs): def numba_funcify_TRSYL(op, node, **kwargs):
in_dtype_a = node.inputs[0].type.numpy_dtype in_dtype_a = node.inputs[0].type.numpy_dtype
......
...@@ -20,6 +20,7 @@ from pytensor.tensor.slinalg import ( ...@@ -20,6 +20,7 @@ from pytensor.tensor.slinalg import (
    lu,
    lu_factor,
    lu_solve,
    qz,
    schur,
    solve,
    solve_triangular,
...@@ -793,6 +794,103 @@ class TestDecompositions: ...@@ -793,6 +794,103 @@ class TestDecompositions:
np.testing.assert_allclose(Z_c, Z_res, atol=1e-6) np.testing.assert_allclose(Z_c, Z_res, atol=1e-6)
np.testing.assert_allclose(val_c_contig, A_val) np.testing.assert_allclose(val_c_contig, A_val)
@pytest.mark.parametrize(
"output, input_type, sort, return_eigenvalues",
[
("real", "real", None, False),
("complex", "real", "lhp", True),
("real", "complex", "ouc", False),
("complex", "complex", None, True),
("real", "real", "iuc", True),
],
ids=[
"real_nosort",
"real_to_complex_sort",
"complex_sort",
"complex_nosort_eig",
"real_sort_eig",
],
)
def test_qz(self, output, input_type, sort, return_eigenvalues):
shape = (5, 5)
dtype = (
config.floatX
if input_type == "real"
else ("complex64" if config.floatX.endswith("32") else "complex128")
)
A = pt.tensor("A", shape=shape, dtype=dtype)
B = pt.tensor("B", shape=shape, dtype=dtype)
outputs = qz(
A, B, output=output, sort=sort, return_eigenvalues=return_eigenvalues
)
if return_eigenvalues:
AA, BB, alpha, beta, Q, Z = outputs
output_list = [AA, BB, alpha, beta, Q, Z]
else:
AA, BB, Q, Z = outputs
output_list = [AA, BB, Q, Z]
rng = np.random.default_rng()
A_val = rng.normal(size=shape).astype(dtype)
B_val = rng.normal(size=shape).astype(dtype)
fn, res = compare_numba_and_py(
[A, B],
output_list,
[A_val, B_val],
numba_mode=numba_inplace_mode,
inplace=True,
)
if return_eigenvalues:
AA_res, BB_res, alpha_res, beta_res, Q_res, Z_res = res
else:
AA_res, BB_res, Q_res, Z_res = res
expected_complex_output = input_type == "complex" or output == "complex"
assert np.iscomplexobj(AA_res) == expected_complex_output
assert np.iscomplexobj(BB_res) == expected_complex_output
assert np.iscomplexobj(Q_res) == expected_complex_output
assert np.iscomplexobj(Z_res) == expected_complex_output
# Verify reconstruction: Q @ AA @ Z.conj().T = A, Q @ BB @ Z.conj().T = B
A_rebuilt = Q_res @ AA_res @ Z_res.conj().T
B_rebuilt = Q_res @ BB_res @ Z_res.conj().T
np.testing.assert_allclose(A_val, A_rebuilt, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(B_val, B_rebuilt, atol=1e-5, rtol=1e-5)
# Test F-contiguous input
A_val_f_contig = np.copy(A_val, order="F")
B_val_f_contig = np.copy(B_val, order="F")
res_f = fn(A_val_f_contig, B_val_f_contig)
if return_eigenvalues:
AA_f, BB_f, alpha_f, beta_f, Q_f, Z_f = res_f
np.testing.assert_allclose(alpha_f, alpha_res, atol=1e-6)
np.testing.assert_allclose(beta_f, beta_res, atol=1e-6)
else:
AA_f, BB_f, Q_f, Z_f = res_f
np.testing.assert_allclose(AA_f, AA_res, atol=1e-6)
np.testing.assert_allclose(BB_f, BB_res, atol=1e-6)
np.testing.assert_allclose(Q_f, Q_res, atol=1e-6)
np.testing.assert_allclose(Z_f, Z_res, atol=1e-6)
# Test C-contiguous input
A_val_c_contig = np.copy(A_val, order="C")
B_val_c_contig = np.copy(B_val, order="C")
res_c = fn(A_val_c_contig, B_val_c_contig)
if return_eigenvalues:
AA_c, BB_c, alpha_c, beta_c, Q_c, Z_c = res_c
np.testing.assert_allclose(alpha_c, alpha_res, atol=1e-6)
np.testing.assert_allclose(beta_c, beta_res, atol=1e-6)
else:
AA_c, BB_c, Q_c, Z_c = res_c
np.testing.assert_allclose(AA_c, AA_res, atol=1e-6)
np.testing.assert_allclose(BB_c, BB_res, atol=1e-6)
np.testing.assert_allclose(Q_c, Q_res, atol=1e-6)
np.testing.assert_allclose(Z_c, Z_res, atol=1e-6)
def test_block_diag(): def test_block_diag():
A = pt.matrix("A") A = pt.matrix("A")
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment