提交 f7d1c644 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Numba dispatch for linear control Ops

上级 a3bf6bb6
...@@ -894,3 +894,55 @@ class _LAPACK: ...@@ -894,3 +894,55 @@ class _LAPACK:
) )
return gees return gees
@classmethod
def numba_xtrsyl(cls, dtype):
    """
    Solve the Sylvester equation A*X + ISGN*X*B = C or A**T*X + ISGN*X*B**T = C.

    Called by scipy.linalg.solve_sylvester and scipy.linalg.solve_continuous_lyapunov.

    Parameters
    ----------
    dtype
        Numba scalar type of the matrices. ``get_blas_kind`` maps it to the
        LAPACK routine prefix used to look up the matching ``?trsyl``.

    Returns
    -------
    Callable
        An njit-compiled ``trsyl(TRANA, TRANB, ISGN, M, N, A, LDA, B, LDB, C,
        LDC, SCALE, INFO)`` that forwards every argument, as a pointer, to the
        LAPACK routine. The wrapper itself returns nothing.
    """
    kind = get_blas_kind(dtype)
    float_pointer = _get_nb_float_from_dtype(kind)

    # SCALE is a real scalar even for the complex routines, so it needs a
    # pointer to the underlying real type rather than to the matrix dtype.
    if kind in ("s", "d"):
        # Real routines: SCALE has the same type as the matrices.
        real_pointer = float_pointer
    else:
        real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p

    unique_func_name = f"scipy.lapack.{kind}trsyl"

    @numba_basic.numba_njit
    def get_trsyl_pointer():
        # The function address is looked up in object mode and handed back
        # to compiled code as a plain integer.
        with numba.objmode(ptr=types.intp):
            ptr = get_lapack_ptr(dtype, "trsyl")
        return ptr

    trsyl_function_type = types.FunctionType(
        types.void(
            nb_i32p,  # TRANA
            nb_i32p,  # TRANB
            nb_i32p,  # ISGN
            nb_i32p,  # M
            nb_i32p,  # N
            float_pointer,  # A
            nb_i32p,  # LDA
            float_pointer,  # B
            nb_i32p,  # LDB
            float_pointer,  # C
            nb_i32p,  # LDC
            real_pointer,  # SCALE
            nb_i32p,  # INFO
        )
    )

    @numba_basic.numba_njit
    def trsyl(TRANA, TRANB, ISGN, M, N, A, LDA, B, LDB, C, LDC, SCALE, INFO):
        fn = _call_cached_ptr(
            get_ptr_func=get_trsyl_pointer,
            func_type_ref=trsyl_function_type,
            unique_func_name_lit=unique_func_name,
        )
        fn(TRANA, TRANB, ISGN, M, N, A, LDA, B, LDB, C, LDC, SCALE, INFO)

    return trsyl
from collections.abc import Callable
from typing import cast
import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match,
_check_linalg_matrix,
)
def _trsyl(a: np.ndarray, b: np.ndarray, c: np.ndarray, overwrite_c):
"""Placeholder for real TRSYL (Sylvester equation solver)."""
fn = cast(Callable, get_lapack_funcs("trsyl", (a, b, c)))
x, scale, info = fn(a, b, c, overwrite_c=overwrite_c)
if info < 0:
return np.full_like(c, np.nan)
x *= scale
return x
@overload(_trsyl)
def trsyl_impl(A, B, C, overwrite_c):
    """
    Numba overload of ``_trsyl``.

    Validates at typing time that A, B and C are 2-d float or complex arrays
    sharing one dtype, then returns an implementation that calls LAPACK
    ``?trsyl`` with TRANA = TRANB = 'N' and ISGN = 1, i.e. solves
    ``A @ X + X @ B = C``. A and B are expected to be in Schur form.
    """
    ensure_lapack()
    for matrix in (A, B, C):
        _check_linalg_matrix(
            matrix, ndim=2, dtype=(Float, Complex), func_name="trsyl"
        )
    _check_dtypes_match((A, B, C), func_name="trsyl")

    dtype = A.dtype
    scale_dtype = _get_underlying_float(dtype)
    numba_xtrsyl = _LAPACK().numba_xtrsyl(dtype)

    def impl(A, B, C, overwrite_c):
        n_a = np.int32(A.shape[-1])
        n_b = np.int32(B.shape[-1])

        A_fortran = _copy_to_fortran_order(A)
        B_fortran = _copy_to_fortran_order(B)

        # Reuse C's buffer only when the caller allows overwriting it and it
        # is already Fortran-contiguous; otherwise work on a copy.
        if overwrite_c and C.flags.f_contiguous:
            X = C
        else:
            X = _copy_to_fortran_order(C)

        # Every scalar argument is handed to LAPACK through a pointer.
        TRANA = val_to_int_ptr(ord("N"))
        TRANB = val_to_int_ptr(ord("N"))
        ISGN = val_to_int_ptr(1)
        M = val_to_int_ptr(n_a)
        N = val_to_int_ptr(n_b)
        LDA = val_to_int_ptr(n_a)
        LDB = val_to_int_ptr(n_b)
        LDC = val_to_int_ptr(n_a)
        SCALE = np.array(1.0, dtype=scale_dtype)
        INFO = val_to_int_ptr(0)

        numba_xtrsyl(
            TRANA,
            TRANB,
            ISGN,
            M,
            N,
            A_fortran.ctypes,
            LDA,
            B_fortran.ctypes,
            LDB,
            X.ctypes,
            LDC,
            SCALE.ctypes,
            INFO,
        )

        # A negative INFO yields an all-NaN result instead of an exception.
        if int_ptr_to_val(INFO) < 0:
            return np.full_like(X, np.nan)

        # X now holds the solution as written by LAPACK; apply SCALE in place.
        X *= SCALE
        return X

    return impl
...@@ -31,6 +31,9 @@ from pytensor.link.numba.dispatch.linalg.decomposition.schur import ( ...@@ -31,6 +31,9 @@ from pytensor.link.numba.dispatch.linalg.decomposition.schur import (
) )
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.linear_control import (
_trsyl,
)
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
...@@ -42,6 +45,7 @@ from pytensor.link.numba.dispatch.string_codegen import ( ...@@ -42,6 +45,7 @@ from pytensor.link.numba.dispatch.string_codegen import (
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR, QR,
TRSYL,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -529,3 +533,38 @@ def numba_funcify_Schur(op, node, **kwargs): ...@@ -529,3 +533,38 @@ def numba_funcify_Schur(op, node, **kwargs):
cache_version = 1 cache_version = 1
return schur, cache_version return schur, cache_version
@register_funcify_default_op_cache_key(TRSYL)
def numba_funcify_TRSYL(op, node, **kwargs):
    """
    Build the Numba implementation of a ``TRSYL`` node.

    Each of the three inputs whose dtype differs from the output dtype is cast
    to the output dtype before ``_trsyl`` is called. When
    ``config.compiler_verbose`` is set, one message is printed per input that
    needs a cast.

    Returns the njit-compiled function together with its cache version.
    """
    out_dtype = node.outputs[0].type.numpy_dtype
    overwrite_c = op.overwrite_c

    def differs_from_output(position):
        return node.inputs[position].type.numpy_dtype != out_dtype

    must_cast_a = differs_from_output(0)
    must_cast_b = differs_from_output(1)
    must_cast_c = differs_from_output(2)

    cast_notices = (
        (must_cast_a, "TRSYL requires casting first input `A`"),
        (must_cast_b, "TRSYL requires casting second input `B`"),
        (must_cast_c, "TRSYL requires casting third input `C`"),
    )
    for needs_cast, notice in cast_notices:
        if needs_cast and config.compiler_verbose:
            print(notice)  # noqa: T201

    @numba_basic.numba_njit
    def trsyl(a, b, c):
        # The must_cast_* flags are fixed when this function is built.
        if must_cast_a:
            a = a.astype(out_dtype)
        if must_cast_b:
            b = b.astype(out_dtype)
        if must_cast_c:
            c = c.astype(out_dtype)
        return _trsyl(a, b, c, overwrite_c=overwrite_c)

    cache_version = 1
    return trsyl, cache_version
import numpy as np
import pytest
from pytensor import config
from pytensor import tensor as pt
from tests.link.numba.test_basic import compare_numba_and_py
floatX = config.floatX

# A warnings filter spec has the form ``action:message:category``. The
# category therefore belongs in the third field, with the message field left
# empty; placing the class path right after ``ignore:`` would treat it as a
# message pattern instead.
pytestmark = pytest.mark.filterwarnings(
    "ignore::numba.core.errors.NumbaPerformanceWarning"
)
def test_solve_sylvester():
    """Compare the Numba and Python results of ``solve_sylvester`` on 5x5 inputs."""
    A = pt.matrix("A")
    B = pt.matrix("B")
    C = pt.matrix("C")
    X = pt.linalg.solve_sylvester(A, B, C)

    # Fixed seed so that any failure can be reproduced with the same inputs.
    rng = np.random.default_rng(20240101)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    B_val = rng.normal(size=(5, 5)).astype(floatX)
    C_val = rng.normal(size=(5, 5)).astype(floatX)

    compare_numba_and_py([A, B, C], [X], [A_val, B_val, C_val])
def test_solve_continuous_lyapunov():
    """Compare the Numba and Python results of ``solve_continuous_lyapunov``."""
    A = pt.matrix("A")
    Q = pt.matrix("Q")
    X = pt.linalg.solve_continuous_lyapunov(A, Q)

    # Fixed seed so that any failure can be reproduced with the same inputs.
    rng = np.random.default_rng(20240102)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    Q_val = rng.normal(size=(5, 5)).astype(floatX)
    Q_val = Q_val @ Q_val.T  # Make Q symmetric positive definite

    compare_numba_and_py([A, Q], [X], [A_val, Q_val])
@pytest.mark.parametrize("method", ["bilinear", "direct"], ids=str)
def test_solve_discrete_lyapunov(method):
    """Compare the Numba and Python results of ``solve_discrete_lyapunov`` for each method."""
    A = pt.matrix("A")
    Q = pt.matrix("Q")
    X = pt.linalg.solve_discrete_lyapunov(A, Q, method=method)

    # Fixed seed so that any failure can be reproduced with the same inputs.
    rng = np.random.default_rng(20240103)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    Q_val = rng.normal(size=(5, 5)).astype(floatX)
    Q_val = Q_val @ Q_val.T  # Make Q symmetric positive definite

    compare_numba_and_py(
        [A, Q],
        [X],
        [A_val, Q_val],
        # object mode fails with 'numpy.dtypes.Int32DType' object has no attribute 'is_precise'
        # when mode is "bilinear"
        eval_obj_mode=method == "direct",
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论