提交 62ba6c9f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add Numba dispatch for Schur

上级 5f04b911
......@@ -750,3 +750,147 @@ class _LAPACK:
fn(M, N, K, A, LDA, TAU, WORK, LWORK, INFO)
return ungqr
@classmethod
def numba_xgees(cls, dtype):
    """
    Build an njit-able wrapper for LAPACK's ``xGEES`` routine for ``dtype``.

    GEES computes the eigenvalues and, optionally, the right Schur vectors of
    a general (nonsymmetric) square matrix A. Called by scipy.linalg.schur.

    The real routines (sgees/dgees) return eigenvalues split into WR/WI
    arrays, while the complex routines (cgees/zgees) return a single complex
    W array and take an extra real-valued RWORK buffer, so the two cases
    need different function signatures.
    """
    kind = get_blas_kind(dtype)
    float_pointer = _get_nb_float_from_dtype(kind)
    unique_func_name = f"scipy.lapack.{kind}gees"

    @numba_basic.numba_njit
    def get_gees_pointer():
        # Symbol-address lookup can't run in nopython mode, so round-trip
        # through objmode and return the raw pointer as an intp.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "gees")
        return ptr

    if isinstance(dtype, Complex):
        # RWORK for cgees/zgees is real-valued, matching the width of the
        # complex dtype (float64 pointer for complex128, float32 otherwise).
        real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p

        gees_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # JOBVS
                nb_i32p,  # SORT
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                nb_i32p,  # SDIM
                float_pointer,  # W
                float_pointer,  # VS
                nb_i32p,  # LDVS
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                real_pointer,  # RWORK
                nb_i32p,  # BWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def gees(
            JOBVS,
            SORT,
            SELECT,
            N,
            A,
            LDA,
            SDIM,
            W,
            VS,
            LDVS,
            WORK,
            LWORK,
            RWORK,
            BWORK,
            INFO,
        ):
            # Resolve (and cache) the LAPACK function pointer, then forward
            # all pointer arguments verbatim to the Fortran routine.
            fn = _call_cached_ptr(
                get_ptr_func=get_gees_pointer,
                func_type_ref=gees_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                JOBVS,
                SORT,
                SELECT,
                N,
                A,
                LDA,
                SDIM,
                W,
                VS,
                LDVS,
                WORK,
                LWORK,
                RWORK,
                BWORK,
                INFO,
            )

    else:  # Real case: eigenvalues come back split into WR/WI, no RWORK
        gees_function_type = types.FunctionType(
            types.void(
                nb_i32p,  # JOBVS
                nb_i32p,  # SORT
                nb_i32p,  # SELECT
                nb_i32p,  # N
                float_pointer,  # A
                nb_i32p,  # LDA
                nb_i32p,  # SDIM
                float_pointer,  # WR
                float_pointer,  # WI
                float_pointer,  # VS
                nb_i32p,  # LDVS
                float_pointer,  # WORK
                nb_i32p,  # LWORK
                nb_i32p,  # BWORK
                nb_i32p,  # INFO
            )
        )

        @numba_basic.numba_njit
        def gees(
            JOBVS,
            SORT,
            SELECT,
            N,
            A,
            LDA,
            SDIM,
            WR,
            WI,
            VS,
            LDVS,
            WORK,
            LWORK,
            BWORK,
            INFO,
        ):
            # Resolve (and cache) the LAPACK function pointer, then forward
            # all pointer arguments verbatim to the Fortran routine.
            fn = _call_cached_ptr(
                get_ptr_func=get_gees_pointer,
                func_type_ref=gees_function_type,
                unique_func_name_lit=unique_func_name,
            )
            fn(
                JOBVS,
                SORT,
                SELECT,
                N,
                A,
                LDA,
                SDIM,
                WR,
                WI,
                VS,
                LDVS,
                WORK,
                LWORK,
                BWORK,
                INFO,
            )

    return gees
from typing import Any
import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import schur
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def schur_real(
A: np.ndarray,
lwork: Any | None = None,
overwrite_a: Any = False,
):
return schur(
a=A,
output="real",
lwork=lwork,
overwrite_a=overwrite_a,
sort=None,
check_finite=False,
)
def schur_complex(
A: np.ndarray,
lwork: Any | None = None,
overwrite_a: Any = False,
):
return schur(
a=A,
output="complex",
lwork=lwork,
overwrite_a=overwrite_a,
sort=None,
check_finite=False,
)
@overload(schur_real)
def schur_real_impl(A, lwork, overwrite_a):
    """Overload for real Schur decomposition (LAPACK s/dgees)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Float,), func_name="schur")
    dtype = A.dtype
    numba_xgees = _LAPACK().numba_xgees(dtype)

    def real_schur_impl(A, lwork, overwrite_a):
        _N = np.int32(A.shape[-1])

        # lwork=None means "run a workspace-size query first" (LAPACK LWORK=-1)
        if lwork is None:
            lwork = -1

        # GEES overwrites A in place with the Schur form T; reuse the
        # caller's buffer only when allowed and already Fortran-ordered.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        if lwork == -1:
            # One-element scratch: the query writes the optimal size to WORK[0]
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)

        JOBVS = val_to_int_ptr(ord("V"))  # compute Schur vectors
        SORT = val_to_int_ptr(ord("N"))  # no eigenvalue ordering
        SELECT = val_to_int_ptr(0.0)  # dummy callback slot; unused when SORT='N'
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(_N)  # output: number of selected eigenvalues

        # dgees returns eigenvalues split into real (WR) and imaginary (WI) parts
        WR = np.empty(_N, dtype=dtype)
        WI = np.empty(_N, dtype=dtype)
        _LDVS = _N
        LDVS = val_to_int_ptr(_N)
        VS = np.empty((_LDVS, _N), dtype=dtype)  # Schur vectors, filled Fortran-order
        BWORK = val_to_int_ptr(1)  # dummy; unused when SORT='N'
        INFO = val_to_int_ptr(1)

        if lwork == -1:
            # Workspace query: routine returns optimal LWORK in WORK[0]
            numba_xgees(
                JOBVS,
                SORT,
                SELECT,
                N,
                A_copy.ctypes,
                LDA,
                SDIM,
                WR.ctypes,
                WI.ctypes,
                VS.ctypes,
                LDVS,
                WORK.ctypes,
                LWORK,
                BWORK,
                INFO,
            )
            WS_SIZE = np.int32(WORK[0].real)
            LWORK = val_to_int_ptr(WS_SIZE)
            WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual factorization call
        numba_xgees(
            JOBVS,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            SDIM,
            WR.ctypes,
            WI.ctypes,
            VS.ctypes,
            LDVS,
            WORK.ctypes,
            LWORK,
            BWORK,
            INFO,
        )

        # On LAPACK failure, poison the output with NaN instead of raising
        # (raising is awkward inside njit-compiled code)
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan

        # VS was written Fortran-order into a C-ordered array: transpose to
        # recover the Schur-vector matrix Z. Returns (T, Z).
        return A_copy, VS.T

    return real_schur_impl
@overload(schur_complex)
def schur_complex_impl(A, lwork, overwrite_a):
    """Overload for complex Schur decomposition (LAPACK c/zgees)."""
    ensure_lapack()
    _check_linalg_matrix(A, ndim=2, dtype=(Complex,), func_name="schur")
    dtype = A.dtype
    # Width of the real RWORK buffer matches the underlying float of dtype
    w_type = _get_underlying_float(dtype)
    numba_xgees = _LAPACK().numba_xgees(dtype)

    def complex_schur_impl(A, lwork, overwrite_a):
        _N = np.int32(A.shape[-1])

        # lwork=None means "run a workspace-size query first" (LAPACK LWORK=-1)
        if lwork is None:
            lwork = -1

        # GEES overwrites A in place with the Schur form T; reuse the
        # caller's buffer only when allowed and already Fortran-ordered.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        if lwork == -1:
            # One-element scratch: the query writes the optimal size to WORK[0]
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)

        JOBVS = val_to_int_ptr(ord("V"))  # compute Schur vectors
        SORT = val_to_int_ptr(ord("N"))  # no eigenvalue ordering
        SELECT = val_to_int_ptr(0.0)  # dummy callback slot; unused when SORT='N'
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        SDIM = val_to_int_ptr(_N)  # output: number of selected eigenvalues

        # zgees returns eigenvalues as a single complex array W
        W = np.empty(_N, dtype=dtype)
        _LDVS = _N
        LDVS = val_to_int_ptr(_N)
        VS = np.empty((_LDVS, _N), dtype=dtype)  # Schur vectors, filled Fortran-order
        RWORK = np.empty(_N, dtype=w_type)  # real scratch required by c/zgees
        BWORK = val_to_int_ptr(1)  # dummy; unused when SORT='N'
        INFO = val_to_int_ptr(1)

        if lwork == -1:
            # Workspace query: routine returns optimal LWORK in WORK[0]
            numba_xgees(
                JOBVS,
                SORT,
                SELECT,
                N,
                A_copy.ctypes,
                LDA,
                SDIM,
                W.ctypes,
                VS.ctypes,
                LDVS,
                WORK.ctypes,
                LWORK,
                RWORK.ctypes,
                BWORK,
                INFO,
            )
            WS_SIZE = np.int32(WORK[0].real)
            LWORK = val_to_int_ptr(WS_SIZE)
            WORK = np.empty(WS_SIZE, dtype=dtype)

        # Actual factorization call
        numba_xgees(
            JOBVS,
            SORT,
            SELECT,
            N,
            A_copy.ctypes,
            LDA,
            SDIM,
            W.ctypes,
            VS.ctypes,
            LDVS,
            WORK.ctypes,
            LWORK,
            RWORK.ctypes,
            BWORK,
            INFO,
        )

        # On LAPACK failure, poison the output with NaN instead of raising
        # (raising is awkward inside njit-compiled code)
        if int_ptr_to_val(INFO) != 0:
            A_copy[:] = np.nan

        # VS was written Fortran-order into a C-ordered array: transpose to
        # recover the Schur-vector matrix Z. Returns (T, Z).
        return A_copy, VS.T

    return complex_schur_impl
......@@ -25,6 +25,10 @@ from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
_qr_raw_no_pivot,
_qr_raw_pivot,
)
from pytensor.link.numba.dispatch.linalg.decomposition.schur import (
schur_complex,
schur_real,
)
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
......@@ -43,6 +47,7 @@ from pytensor.tensor.slinalg import (
CholeskySolve,
LUFactor,
PivotToPermutations,
Schur,
Solve,
SolveTriangular,
)
......@@ -469,3 +474,58 @@ def numba_funcify_QR(op, node, **kwargs):
cache_version = 2
return qr, cache_version
@register_funcify_default_op_cache_key(Schur)
def numba_funcify_Schur(op, node, **kwargs):
    """Return a numba implementation of the Schur Op.

    Falls back to object mode when ``sort`` is requested (the sorting
    callback is not supported in nopython mode). Integer inputs, and real
    inputs with ``output="complex"``, are cast to the output dtype before
    the factorization; the cast makes in-place operation impossible, so
    ``overwrite_a`` is disabled in that case.
    """
    output = op.output
    overwrite_a = op.overwrite_a
    sort = op.sort

    # Guard clause: sorting needs a python callback, so fall back.
    if sort is not None:
        if config.compiler_verbose:
            print(  # noqa: T201
                "Schur is not implemented in numba mode when `sort` is not None, "
                "falling back to object mode"
            )
        return generate_fallback_impl(op, node=node, **kwargs)

    in_dtype = node.inputs[0].type.numpy_dtype
    out_dtype = node.outputs[0].type.numpy_dtype

    integer_input = in_dtype.kind in "ibu"
    complex_input = in_dtype.kind in "cz"
    needs_complex_cast = in_dtype.kind in "fd" and output == "complex"

    if needs_complex_cast:
        # The real->complex upcast allocates a new array, so in-place
        # operation on the input is impossible.
        overwrite_a = False
        if config.compiler_verbose:
            print(  # noqa: T201
                "Schur: disabling overwrite_a due to dtype conversion (casting prevents in-place operation)"
            )
    if integer_input and config.compiler_verbose:
        print("Schur requires casting discrete input to float")  # noqa: T201

    # Any cast (discrete -> float, or real -> complex) happens up front;
    # both flags are compile-time constants inside the njit closures.
    must_cast = integer_input or needs_complex_cast

    # Complex input always produces complex output, and output == "complex"
    # forces the complex routine even for real input.
    if complex_input or output == "complex":

        @numba_basic.numba_njit
        def schur(a):
            if must_cast:
                a = a.astype(out_dtype)
            return schur_complex(a, lwork=None, overwrite_a=overwrite_a)

    else:

        @numba_basic.numba_njit
        def schur(a):
            if must_cast:
                a = a.astype(out_dtype)
            return schur_real(a, lwork=None, overwrite_a=overwrite_a)

    cache_version = 1
    return schur, cache_version
......@@ -20,6 +20,7 @@ from pytensor.tensor.slinalg import (
lu,
lu_factor,
lu_solve,
schur,
solve,
solve_triangular,
)
......@@ -735,6 +736,63 @@ class TestDecompositions:
[np.zeros((0, 0))],
)
@pytest.mark.parametrize("output", ["real", "complex"], ids=lambda x: f"output_{x}")
@pytest.mark.parametrize(
    "input_type", ["real", "complex"], ids=lambda x: f"input_{x}"
)
@pytest.mark.parametrize(
    "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"]
)
def test_schur(self, output, input_type, overwrite_a):
    """Numba Schur dispatch: compare against the python implementation,
    check output dtypes, reconstruction A = Z T Z^H, and the in-place
    (overwrite_a) semantics for F- and C-contiguous inputs."""
    shape = (5, 5)

    # Scipy only respects output parameter for real inputs
    # Complex inputs always produce complex output
    requires_casting = input_type == "real" and output == "complex"
    dtype = (
        config.floatX
        if input_type == "real"
        else ("complex64" if config.floatX.endswith("32") else "complex128")
    )
    A = pt.tensor("A", shape=shape, dtype=dtype)
    T, Z = schur(A, output=output)

    rng = np.random.default_rng()
    # NOTE(review): for complex input this casts real draws to complex, so
    # the imaginary part is zero — the data is real-valued in a complex dtype.
    A_val = rng.normal(size=shape).astype(dtype)

    # In(mutable=...) + inplace mode lets the op destroy its input when allowed
    fn, (T_res, Z_res) = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        [T, Z],
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    expected_complex_output = input_type == "complex" or output == "complex"
    assert (
        np.iscomplexobj(T_res) and np.iscomplexobj(Z_res)
    ) == expected_complex_output

    # Verify reconstruction: A == Z @ T @ Z^H
    A_rebuilt = Z_res @ T_res @ Z_res.conj().T
    np.testing.assert_allclose(A_val, A_rebuilt, atol=1e-6, rtol=1e-6)

    # Test F-contiguous input: only this layout may actually be destroyed
    val_f_contig = np.copy(A_val, order="F")
    T_f, Z_f = fn(val_f_contig)
    np.testing.assert_allclose(T_f, T_res, atol=1e-6)
    np.testing.assert_allclose(Z_f, Z_res, atol=1e-6)
    # The cast path disables overwrite_a, so casting implies no destruction
    expect_destroy = overwrite_a and not requires_casting
    assert (A_val == val_f_contig).all() == (not expect_destroy)

    # Test C-contiguous input (cannot destroy: a Fortran copy is made)
    val_c_contig = np.copy(A_val, order="C")
    T_c, Z_c = fn(val_c_contig)
    np.testing.assert_allclose(T_c, T_res, atol=1e-6)
    np.testing.assert_allclose(Z_c, Z_res, atol=1e-6)
    np.testing.assert_allclose(val_c_contig, A_val)
def test_block_diag():
A = pt.matrix("A")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论