提交 5f04b911 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

always include inplace rewrite during inplace test

Add Schur decomposition Op
上级 975ca888
......@@ -2068,6 +2068,193 @@ def qr(
return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A)
class Schur(Op):
"""
Schur Decomposition
"""
__props__ = ("output", "overwrite_a", "sort")
def __init__(
self,
output: Literal["real", "complex"] = "real",
overwrite_a: bool = False,
sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
):
self.output = output
self.gufunc_signature = "(m,m)->(m,m),(m,m)"
self.overwrite_a = overwrite_a
self.sort = sort
self.destroy_map = {0: [0]} if overwrite_a else {}
if output not in ["real", "complex"]:
raise ValueError("output must be 'real' or 'complex'")
if sort is not None and sort not in ("lhp", "rhp", "iuc", "ouc"):
raise ValueError("sort must be None or one of ('lhp', 'rhp', 'iuc', 'ouc')")
def make_sort_function(self):
sort = self.sort
sort_t = 1
match sort:
case None:
sort_t = 0
def sort_function(x, y=None):
return None
case "lhp":
def sort_function(x, y=None):
return x.real < 0.0
case "rhp":
def sort_function(x, y=None):
return x.real >= 0.0
case "iuc":
def sort_function(x, y=None):
z = x if y is None else x + y * 1j
return abs(z) <= 1.0
case "ouc":
def sort_function(x, y=None):
z = x if y is None else x + y * 1j
return abs(z) > 1.0
case _:
raise ValueError(
"sort must be None or one of ('lhp', 'rhp', 'iuc', 'ouc')"
)
return sort_function, sort_t
def make_node(self, A):
A = as_tensor_variable(A)
assert A.ndim == 2
out_dtype = A.dtype
complex_input = out_dtype in ("complex64", "complex128")
# Scipy behavior: output parameter only affects real inputs
# Complex inputs always return complex output
if self.output == "complex" and not complex_input:
out_dtype = "complex64" if A.dtype == "float32" else "complex128"
T = matrix(dtype=out_dtype, shape=A.type.shape)
Z = matrix(dtype=out_dtype, shape=A.type.shape)
return Apply(self, [A], [T, Z])
def perform(self, node, inputs, outputs):
(A,) = inputs
(T_out, Z_out) = outputs
overwrite_a = self.overwrite_a
A_work = A
if self.output == "complex" and not np.iscomplexobj(A):
overwrite_a = False
if A.dtype == np.float32:
A_work = A.astype(np.complex64)
else:
A_work = A.astype(np.complex128)
if self.output == "real" and np.iscomplexobj(A):
overwrite_a = False
(gees,) = scipy_linalg.get_lapack_funcs(("gees",), dtype=A_work.dtype)
if A_work.size == 0:
T_out[0] = np.empty_like(A_work, dtype=gees.dtype)
Z_out[0] = np.empty_like(A_work, dtype=gees.dtype)
return
if not np.isfinite(A_work).all():
T_out[0] = np.full(A_work.shape, np.nan, dtype=gees.dtype)
Z_out[0] = np.full(A_work.shape, np.nan, dtype=gees.dtype)
return
sort_function, sort_t = self.make_sort_function()
*_, work, _info = gees(
sort_function, A_work, lwork=-1, overwrite_a=False, sort_t=sort_t
)
lwork = int(work[0].real)
result = gees(
sort_function,
A_work,
lwork=lwork,
overwrite_a=overwrite_a,
sort_t=sort_t,
)
if np.iscomplexobj(A_work):
T, _sdim, _w, Z, _work, info = result
else:
T, _sdim, _wr, _wi, Z, _work, info = result
if info != 0:
T_out[0] = np.full(A_work.shape, np.nan, dtype=node.outputs[0].type.dtype)
Z_out[0] = np.full(A_work.shape, np.nan, dtype=node.outputs[1].type.dtype)
else:
T_out[0] = T
Z_out[0] = Z
def infer_shape(self, fgraph, node, shapes):
return [shapes[0], shapes[0]]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs:
return self
new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
def schur(
    A: TensorLike,
    output: Literal["real", "complex"] = "real",
    sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Compute the Schur decomposition of `A`.

    A Schur decomposition factors a square matrix as :math:`A = Z T Z^H`, with `Z` unitary and
    `T` upper-triangular (complex Schur form) or quasi-upper-triangular (real Schur form,
    obtained with output='real').

    The core `Schur` Op is wrapped in `Blockwise` before being applied to `A`.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape (M, M) to decompose.
    output: str, one of "real" or "complex"
        Only relevant for real-valued `A`: output='real' gives a quasi-upper-triangular Schur
        form, output='complex' gives an upper-triangular one. For complex-valued `A` the Schur
        form is upper-triangular whatever the value of this argument.
    sort: str or None, optional
        Rule used to select which eigenvalues are sorted to the upper-left. One of:

        - None (default): no sorting
        - 'lhp': left half-plane (real(λ) < 0)
        - 'rhp': right half-plane (real(λ) >= 0)
        - 'iuc': inside unit circle (abs(λ) <= 1)
        - 'ouc': outside unit circle (abs(λ) > 1)

    Returns
    -------
    T : TensorVariable
        Schur form of A: upper-triangular, or quasi-upper-triangular when output='real'.
    Z : TensorVariable
        Unitary transformation matrix satisfying A = Z @ T @ Z.conj().T
    """
    core_op = Schur(output=output, sort=sort)
    batched_op = Blockwise(core_op)
    return batched_op(A)  # type: ignore[return-value]
__all__ = [
"block_diag",
"cho_solve",
......@@ -2078,6 +2265,7 @@ __all__ = [
"lu_factor",
"lu_solve",
"qr",
"schur",
"solve",
"solve_continuous_lyapunov",
"solve_discrete_are",
......
......@@ -8,7 +8,7 @@ import pytest
import scipy
from scipy import linalg as scipy_linalg
from pytensor import function, grad
from pytensor import In, function, grad
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
......@@ -31,6 +31,7 @@ from pytensor.tensor.slinalg import (
lu_solve,
pivot_to_permutation,
qr,
schur,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
......@@ -1248,3 +1249,92 @@ def test_qr_grad(shape, gradient_test_case, mode, is_complex):
utt.verify_grad(
partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
)
class TestSchur:
    """Tests for ``schur``: reconstruction, agreement with SciPy, sorting and empty input."""

    @pytest.mark.parametrize(
        "shape, output",
        [((5, 5), "real"), ((5, 5), "complex"), ((2, 4, 4), "real")],
        ids=["not_batched_real", "not_batched_complex", "batched_real"],
    )
    @pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
    def test_schur_decomposition(self, shape, output, complex):
        """Check reconstruction, SciPy agreement and the inplace variant."""
        if complex:
            # Complex counterpart of floatX: twice as many bits.
            dtype = f"complex{int(config.floatX[-2:]) * 2}"
        else:
            dtype = config.floatX
        rebuild_tol = 1e-6 if config.floatX == "float64" else 1e-3

        A = tensor("A", shape=shape, dtype=dtype)
        T, Z = schur(A, output=output)
        schur_fn = function([A], [T, Z])

        rng = np.random.default_rng(utt.fetch_seed())
        x = rng.normal(size=shape).astype(config.floatX)
        if complex:
            x = x + 1j * rng.normal(size=shape).astype(config.floatX)

        T_val, Z_val = schur_fn(x)

        # Verify reconstruction: Z @ T @ Z^H must give back the input.
        x_rebuilt = np.einsum("...ij,...jk,...lk->...il", Z_val, T_val, Z_val.conj())
        np.testing.assert_allclose(x, x_rebuilt, atol=rebuild_tol, rtol=rebuild_tol)

        # Reference values from SciPy, vectorized over any leading batch dimensions.
        def scipy_schur(a):
            return scipy_linalg.schur(a, output=output)

        vec_schur = np.vectorize(scipy_schur, signature="(m,m)->(m,m),(m,m)")
        scipy_T, scipy_Z = vec_schur(x)

        np.testing.assert_allclose(T_val, scipy_T, atol=1e-6, rtol=1e-6)
        np.testing.assert_allclose(Z_val, scipy_Z, atol=1e-6, rtol=1e-6)

        # The inplace check only runs for unbatched inputs whose complexity
        # matches the requested output form.
        wants_complex = output == "complex"
        if len(shape) == 2 and wants_complex == complex:
            x_f = np.asfortranarray(x.copy())
            inplace_fn = function(
                [In(A, mutable=True)],
                [T, Z],
                mode=get_default_mode().including("inplace"),
            )
            inplace_fn(x_f)
            # The input buffer is expected to hold T after the call.
            np.testing.assert_allclose(x_f, scipy_T, atol=1e-6, rtol=1e-6)

    @pytest.mark.parametrize("sort", ["lhp", "rhp", "iuc", "ouc"])
    def test_schur_sort(self, sort):
        """Check every sort option against SciPy on a real 3x3 matrix."""
        rebuild_tol = 1e-6 if config.floatX == "float64" else 1e-3

        rng = np.random.default_rng(utt.fetch_seed())
        x = rng.normal(size=(3, 3)).astype(config.floatX)

        A = matrix("A", dtype=config.floatX)
        T, Z = schur(A, sort=sort)
        schur_fn = function([A], [T, Z])

        T_val, Z_val = schur_fn(x)

        x_rebuilt = Z_val @ T_val @ Z_val.T
        np.testing.assert_allclose(x, x_rebuilt, atol=rebuild_tol, rtol=rebuild_tol)

        # With sort given, SciPy returns a third value, which is ignored here.
        scipy_T, scipy_Z, _ = scipy_linalg.schur(x, output="real", sort=sort)
        np.testing.assert_allclose(T_val, scipy_T, atol=1e-6, rtol=1e-6)
        np.testing.assert_allclose(Z_val, scipy_Z, atol=1e-6, rtol=1e-6)

    def test_schur_empty(self):
        """A 0x0 input yields empty outputs with dtype floatX."""
        empty = np.empty([0, 0], dtype=config.floatX)

        A = matrix()
        T, Z = schur(A)
        schur_fn = function([A], [T, Z])

        results = schur_fn(empty)
        T_val, Z_val = results

        assert T_val.size == 0
        assert Z_val.size == 0
        assert T_val.dtype == config.floatX
        assert Z_val.dtype == config.floatX
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论