提交 ad8dca48 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Implement QZ Op

上级 9834e96c
......@@ -1963,6 +1963,415 @@ def schur(
return Blockwise(Schur(output=output, sort=sort))(A) # type: ignore[return-value]
class QZ(Op):
    """
    QZ (generalized Schur) decomposition of a square matrix pair ``(A, B)``.

    The Op calls LAPACK ``gges`` without sorting and, when ``sort`` is set,
    reorders the result with LAPACK ``tgsen``.

    Parameters
    ----------
    complex_output : bool
        If True, real inputs are promoted to complex before the decomposition.
        Complex inputs always yield complex outputs.
    overwrite_a, overwrite_b : bool
        Allow LAPACK to overwrite the corresponding input buffer. Recorded in
        ``destroy_map``.
    sort : {"lhp", "rhp", "iuc", "ouc"} or None
        Eigenvalue group to move to the upper-left of the Schur forms, or None
        for no reordering.
    return_eigenvalues : bool
        If True, the outputs are ``(AA, BB, alpha, beta, Q, Z)``; otherwise
        ``(AA, BB, Q, Z)``.

    Raises
    ------
    ValueError
        If ``sort`` is not None and not one of the supported names.
    """

    __props__ = (
        "complex_output",
        "overwrite_a",
        "overwrite_b",
        "sort",
        "return_eigenvalues",
    )

    def __init__(
        self,
        complex_output: bool = False,
        overwrite_a: bool = False,
        overwrite_b: bool = False,
        sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
        return_eigenvalues: bool = False,
    ):
        # Validate before touching any state so a failed construction leaves nothing behind.
        if sort is not None and sort not in ("lhp", "rhp", "iuc", "ouc"):
            raise ValueError("sort must be None or one of ('lhp', 'rhp', 'iuc', 'ouc')")

        self.complex_output = complex_output
        self.overwrite_a = overwrite_a
        self.overwrite_b = overwrite_b
        self.sort = sort
        self.return_eigenvalues = return_eigenvalues

        if return_eigenvalues:
            self.gufunc_signature = "(m,m),(m,m)->(m,m),(m,m),(m),(m),(m,m),(m,m)"
        else:
            self.gufunc_signature = "(m,m),(m,m)->(m,m),(m,m),(m,m),(m,m)"

        self.destroy_map = {}
        if overwrite_a:
            self.destroy_map[0] = [0]
        if overwrite_b:
            self.destroy_map[1] = [1]

    def make_sort_function(
        self, sort: Literal["lhp", "rhp", "iuc", "ouc", "none"] | None = None
    ):
        """
        Build the eigenvalue selection callback for a sort criterion.

        Parameters
        ----------
        sort : str or None
            Criterion to use. None falls back to ``self.sort``; the string
            ``"none"`` explicitly requests the no-op callback.

        Returns
        -------
        sort_function : callable
            ``sort_function(alpha, beta)`` returning a boolean mask of the
            eigenvalues ``alpha / beta`` to select (or None when not sorting).
        sort_t : int
            0 when no sorting is requested, 1 otherwise.

        Raises
        ------
        ValueError
            If the criterion is not recognised.
        """
        if sort is None:
            sort = self.sort

        if sort is None or (isinstance(sort, str) and sort == "none"):

            def sort_function(alpha, beta):
                """No sorting."""
                return None

            return sort_function, 0

        # Test applied to the finite eigenvalues alpha / beta (beta != 0).
        predicates = {
            "lhp": lambda lam: lam.real < 0.0,
            "rhp": lambda lam: lam.real > 0.0,
            "iuc": lambda lam: np.abs(lam) < 1.0,
            "ouc": lambda lam: np.abs(lam) > 1.0,
        }
        if not isinstance(sort, str) or sort not in predicates:
            raise ValueError(
                "sort must be None or one of ('lhp', 'rhp', 'iuc', 'ouc', 'none')"
            )

        predicate = predicates[sort]
        # Only "ouc" selects infinite eigenvalues (beta == 0, alpha != 0).
        # The undefined pair (alpha, beta) == (0, 0) is never selected.
        select_infinite = sort == "ouc"

        def sort_function(alpha, beta):
            """Select the eigenvalues alpha / beta that satisfy the criterion."""
            selected = np.empty(alpha.shape, dtype=bool)
            finite = beta != 0
            if select_infinite:
                selected[~finite] = alpha[~finite] != 0
            else:
                selected[~finite] = False
            selected[finite] = predicate(alpha[finite] / beta[finite])
            return selected

        return sort_function, 1

    def make_node(self, A, B):
        """
        Build the Apply node and decide the output dtypes.

        Integer and boolean inputs are promoted to float32 (itemsize <= 2) or
        float64. With ``complex_output`` a real dtype is upcast to its complex
        counterpart. ``alpha`` is always complex; ``beta`` shares the dtype of
        the matrix outputs.
        """
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)

        assert A.ndim == 2
        assert B.ndim == 2

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
        if np.dtype(out_dtype).kind in "ibu":
            out_dtype = "float64" if np.dtype(out_dtype).itemsize > 2 else "float32"

        complex_input = out_dtype in ("complex64", "complex128")

        # Scipy behavior: output parameter only affects real inputs
        # Complex inputs always return complex output
        if self.complex_output and not complex_input:
            out_dtype = pytensor.scalar.upcast(out_dtype, "complex64")

        AA = matrix(dtype=out_dtype, shape=A.type.shape)
        BB = matrix(dtype=out_dtype, shape=B.type.shape)
        Q = matrix(dtype=out_dtype, shape=A.type.shape)
        Z = matrix(dtype=out_dtype, shape=A.type.shape)

        if not self.return_eigenvalues:
            return Apply(self, [A, B], [AA, BB, Q, Z])

        # Eigenvalues can be complex even for real matrices, so alpha is always complex
        # beta has the same dtype as the matrix outputs
        if complex_input or self.complex_output:
            alpha_dtype = out_dtype
        else:
            alpha_dtype = pytensor.scalar.upcast(out_dtype, "complex64")

        alpha = vector(dtype=alpha_dtype, shape=(A.type.shape[0],))
        beta = vector(dtype=out_dtype, shape=(A.type.shape[0],))
        return Apply(self, [A, B], [AA, BB, alpha, beta, Q, Z])

    def _prepare_input(self, x, dtype, overwrite):
        """Return ``x`` in the working dtype and whether LAPACK may overwrite it."""
        if x.dtype != dtype:
            # The cast produces a private copy; the caller's buffer is never touched.
            return x.astype(dtype), False
        # In-place operation is only enabled when the realness of the input matches
        # the requested output kind.
        return x, bool(overwrite) and (np.iscomplexobj(x) == bool(self.complex_output))

    @staticmethod
    def _store_filled(node, outputs, shape, fill_value):
        """Fill every output storage cell with ``fill_value`` in its declared dtype."""
        for storage, variable in zip(outputs, node.outputs):
            out_shape = tuple(shape) if variable.ndim == 2 else (shape[0],)
            storage[0] = np.full(out_shape, fill_value, dtype=variable.type.dtype)

    @staticmethod
    def _assemble_eigenvalues(typecode, ab):
        """Turn the LAPACK eigenvalue outputs into a complex ``alpha`` and ``beta``."""
        if typecode == "s":
            alphar, alphai, beta = ab
            return alphar + np.complex64(1j) * alphai, beta
        if typecode == "d":
            alphar, alphai, beta = ab
            return alphar + 1j * alphai, beta
        alpha, beta = ab
        return alpha, beta

    def perform(self, node, inputs, outputs):
        """
        Compute the decomposition and write it to ``outputs``.

        If either LAPACK call reports ``info != 0``, every output is filled
        with NaN instead of raising.
        """
        A, B = inputs

        # Work in the dtype declared by make_node so that mixed-dtype inputs
        # (float32/float64, real/complex, integer) select a LAPACK routine that
        # matches the declared outputs.
        work_dtype = np.dtype(node.outputs[0].type.dtype)

        if A.size == 0:
            # Nothing to decompose: return empty outputs of the declared dtypes.
            self._store_filled(node, outputs, A.shape, np.nan)
            return

        A_work, overwrite_a = self._prepare_input(A, work_dtype, self.overwrite_a)
        B_work, overwrite_b = self._prepare_input(B, work_dtype, self.overwrite_b)

        (gges,) = scipy_linalg.get_lapack_funcs(("gges",), dtype=work_dtype)
        gges_type = gges.typecode
        no_sort_fn, no_sort_t = self.make_sort_function(sort="none")

        # Workspace query
        *_, work, _info = gges(
            no_sort_fn,
            A_work,
            B_work,
            lwork=-1,
            overwrite_a=False,
            overwrite_b=False,
            sort_t=no_sort_t,
        )
        lwork = int(work[0].real)

        # This Op is a combination of scipy.linalg.qz and scipy.linalg.ordqz. They first call gges with no sorting,
        # then do the sorting in a second step if required
        AA, BB, _sdim, *ab, Q, Z, _work, info = gges(
            no_sort_fn,
            A_work,
            B_work,
            lwork=lwork,
            overwrite_a=overwrite_a,
            overwrite_b=overwrite_b,
            sort_t=no_sort_t,
        )

        # If this first pass failed, we skip the sorting step no matter what and return NaNs
        # TODO: When info > 0 and info < A.shape[0], gges fails to put A and B in Schur form but the eigenvalues
        #  are still valid. We could potentially still return something in this case.
        if info != 0:
            self._store_filled(node, outputs, A_work.shape, np.nan)
            return

        alpha = beta = None
        if self.sort is not None or self.return_eigenvalues:
            alpha, beta = self._assemble_eigenvalues(gges_type, ab)

        if self.sort is not None:
            sort_function, _ = self.make_sort_function()
            select = sort_function(alpha, beta)

            tgsen = scipy_linalg.get_lapack_funcs("tgsen", (AA, BB))
            # The real routines are given 4n + 16 workspace entries; complex ones get 1.
            lwork = 4 * AA.shape[0] + 16 if gges_type in "sd" else 1
            AA, BB, *ab, Q, Z, _, _, _, _, info = tgsen(
                select,
                AA,
                BB,
                Q,
                Z,
                ijob=0,
                lwork=lwork,
                liwork=1,
                overwrite_a=overwrite_a,
                overwrite_b=overwrite_b,
            )
            if info != 0:
                self._store_filled(node, outputs, A_work.shape, np.nan)
                return
            alpha, beta = self._assemble_eigenvalues(gges_type, ab)

        if self.return_eigenvalues:
            results = (AA, BB, alpha, beta, Q, Z)
        else:
            results = (AA, BB, Q, Z)
        for storage, value in zip(outputs, results):
            storage[0] = value

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shapes implied by the shapes of ``A`` and ``B``."""
        A_shape, B_shape = shapes
        if self.return_eigenvalues:
            return [A_shape, B_shape, (A_shape[0],), (A_shape[0],), A_shape, B_shape]
        return [A_shape, B_shape, A_shape, B_shape]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a copy of this Op that may overwrite the listed input indices."""
        if not allowed_inplace_inputs:
            return self

        new_props = self._props_dict()  # type: ignore
        if 0 in allowed_inplace_inputs:
            new_props["overwrite_a"] = True
        if 1 in allowed_inplace_inputs:
            new_props["overwrite_b"] = True
        return type(self)(**new_props)
def qz(
    A: TensorLike,
    B: TensorLike,
    output: Literal["real", "complex"] = "real",
    sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
    return_eigenvalues: bool = False,
) -> (
    tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]
    | tuple[
        TensorVariable,
        TensorVariable,
        TensorVariable,
        TensorVariable,
        TensorVariable,
        TensorVariable,
    ]
):
    """
    QZ Decomposition of input matrix pair `(A, B)`.

    The QZ decomposition (also known as the generalized Schur decomposition) of a matrix pair
    `(A, B)` is a factorization of the form :math:`A = Q H Z^H` and :math:`B = Q K Z^H`,
    where `Q` and `Z` are unitary matrices, and `H` and `K` are upper-triangular matrices.

    Parameters
    ----------
    A: TensorLike
        First input square matrix of shape (M, M) to be decomposed.
    B: TensorLike
        Second input square matrix of shape (M, M) to be decomposed.
    output: str, one of "real" or "complex"
        For real-valued `A` and `B`, if output='real', then the Schur forms are quasi-upper-triangular.
        If output='complex', the Schur forms are upper-triangular. For complex-valued `A` and `B`,
        the Schur forms are always upper-triangular regardless of the output parameter.
    sort: str or None, optional
        Specifies which generalized eigenvalues λ = alpha / beta are moved to the upper-left
        of the Schur forms. All comparisons are strict. Available options are:

        - None (default): eigenvalues are not sorted
        - 'lhp': left half-plane (real(λ) < 0)
        - 'rhp': right half-plane (real(λ) > 0)
        - 'iuc': inside unit circle (abs(λ) < 1)
        - 'ouc': outside unit circle (abs(λ) > 1); infinite eigenvalues
          (beta = 0, alpha != 0) are also selected

    return_eigenvalues: bool, default False
        If True, the function also returns the generalized eigenvalues as two arrays `alpha` and `beta`,
        where the generalized eigenvalues are given by the ratio `alpha / beta`.

    Returns
    -------
    H : TensorVariable
        Schur form of A. An upper-triangular matrix (or quasi-upper-triangular if output='real').
    K : TensorVariable
        Schur form of B. An upper-triangular matrix (or quasi-upper-triangular if output='real').
    alpha : TensorVariable, optional
        Numerators of the generalized eigenvalues. Only present if `return_eigenvalues` is True,
        in which case it is the third output.
    beta : TensorVariable, optional
        Denominators of the generalized eigenvalues. Only present if `return_eigenvalues` is True,
        in which case it is the fourth output.
    Q : TensorVariable
        Unitary matrix such that A = Q @ H @ Z.conj().T and B = Q @ K @ Z.conj().T.
    Z : TensorVariable
        Unitary matrix such that A = Q @ H @ Z.conj().T and B = Q @ K @ Z.conj().T.

    Raises
    ------
    ValueError
        If `output` is not 'real' or 'complex', or if `sort` is not one of the supported values.

    Notes
    -----
    Unlike scipy.linalg.qz, the `sort` argument is accepted. Behavior in this case follows that of
    scipy.linalg.ordqz.

    The outputs are ordered `(H, K, Q, Z)`, or `(H, K, alpha, beta, Q, Z)` when
    `return_eigenvalues` is True.
    """
    if output not in ("real", "complex"):
        raise ValueError("output must be 'real' or 'complex'")

    qz_op = QZ(
        complex_output=(output == "complex"),
        sort=sort,
        return_eigenvalues=return_eigenvalues,
    )
    return Blockwise(qz_op)(A, B)  # type: ignore[return-value]
def ordqz(
    A: TensorLike,
    B: TensorLike,
    sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
    output: Literal["real", "complex"] = "real",
) -> tuple[
    TensorVariable,
    TensorVariable,
    TensorVariable,
    TensorVariable,
    TensorVariable,
    TensorVariable,
]:
    """
    Ordered QZ Decomposition of input matrix pair `(A, B)`.

    Thin wrapper around `qz` that always requests the generalized eigenvalues. Included for API
    consistency with `scipy.linalg`. For details on the decomposition and on the meaning of each
    option, see the docstring of `qz`.

    Parameters
    ----------
    A: TensorLike
        First input square matrix of shape (M, M).
    B: TensorLike
        Second input square matrix of shape (M, M).
    sort: str or None, optional
        One of 'lhp', 'rhp', 'iuc', 'ouc', or None for no reordering. Note that the default here is
        None, whereas `scipy.linalg.ordqz` defaults to 'lhp'.
    output: str, one of "real" or "complex"
        Kind of Schur form to compute for real-valued inputs.

    Returns
    -------
    AA, BB, alpha, beta, Q, Z : TensorVariable
        The Schur forms of `A` and `B`, the numerators and denominators of the generalized
        eigenvalues, and the two transformation matrices, in that order.
    """
    return qz(  # type: ignore[return-value]
        A, B, output=output, sort=sort, return_eigenvalues=True
    )
_deprecated_names = {
"solve_continuous_lyapunov",
"solve_discrete_are",
......@@ -1994,7 +2403,9 @@ __all__ = [
"lu",
"lu_factor",
"lu_solve",
"ordqz",
"qr",
"qz",
"schur",
"solve",
"solve_triangular",
......
......@@ -37,6 +37,7 @@ from pytensor.tensor.slinalg import (
lu_solve,
pivot_to_permutation,
qr,
qz,
schur,
solve,
solve_triangular,
......@@ -1366,3 +1367,103 @@ class TestSchur:
assert Z_out.size == 0
assert T_out.dtype == config.floatX
assert Z_out.dtype == config.floatX
class TestQZ:
    """Checks `qz` against scipy's `qz` / `ordqz` for real, complex and batched inputs."""

    @pytest.mark.parametrize(
        "shape, output",
        [((5, 5), "real"), ((5, 5), "complex"), ((2, 4, 4), "real")],
        ids=["not_batched_real", "not_batched_complex", "batched_real"],
    )
    @pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
    @pytest.mark.parametrize("sort", [None, "lhp", "rhp", "iuc", "ouc"])
    def test_qz_decomposition(self, shape, output, complex, sort):
        # Eigenvalues are only requested (and compared) when a sort criterion is given.
        with_eigenvalues = sort is not None

        if complex:
            dtype = f"complex{int(config.floatX[-2:]) * 2}"
        else:
            dtype = config.floatX

        A = tensor("A", shape=shape, dtype=dtype)
        B = tensor("B", shape=shape, dtype=dtype)
        outputs = qz(
            A, B, output=output, sort=sort, return_eigenvalues=with_eigenvalues
        )
        f = function([A, B], outputs)

        rng = np.random.default_rng(utt.fetch_seed())
        A_val, B_val = rng.normal(size=(2, *shape))
        A_val = A_val.astype(config.floatX)
        B_val = B_val.astype(config.floatX)
        if complex:
            A_val = A_val + 1j * rng.normal(size=shape).astype(config.floatX)
            B_val = B_val + 1j * rng.normal(size=shape).astype(config.floatX)

        output_values = f(A_val, B_val)
        if with_eigenvalues:
            AA_val, BB_val, alpha_val, beta_val, Q_val, Z_val = output_values
        else:
            AA_val, BB_val, Q_val, Z_val = output_values

        # Verify reconstruction: Q @ S @ Z^H must give back the input matrix.
        rebuild_tol = 1e-6 if config.floatX == "float64" else 1e-3
        for original, schur_form in ((A_val, AA_val), (B_val, BB_val)):
            rebuilt = np.einsum(
                "...ij,...jk,...lk->...il", Q_val, schur_form, Z_val.conj()
            )
            np.testing.assert_allclose(
                original, rebuilt, atol=rebuild_tol, rtol=rebuild_tol
            )

        # Reference result from scipy, vectorized over any batch dimensions.
        if with_eigenvalues:
            scipy_fn = functools.partial(scipy_linalg.ordqz, sort=sort)
            scipy_signature = "(m,m),(m,m)->(m,m),(m,m),(m),(m),(m,m),(m,m)"
        else:
            scipy_fn = scipy_linalg.qz
            scipy_signature = "(m,m),(m,m)->(m,m),(m,m),(m,m),(m,m)"

        vec_qz = np.vectorize(
            lambda a, b: scipy_fn(a, b, output=output),
            signature=scipy_signature,
        )
        scipy_result = vec_qz(A_val, B_val)
        if with_eigenvalues:
            scipy_AA, scipy_BB, scipy_alpha, scipy_beta, scipy_Q, scipy_Z = scipy_result
        else:
            scipy_AA, scipy_BB, scipy_Q, scipy_Z = scipy_result

        matrix_pairs = (
            (AA_val, scipy_AA),
            (BB_val, scipy_BB),
            (Q_val, scipy_Q),
            (Z_val, scipy_Z),
        )
        for actual, expected in matrix_pairs:
            np.testing.assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)

        if with_eigenvalues:
            np.testing.assert_allclose(alpha_val, scipy_alpha, atol=1e-6, rtol=1e-6)
            np.testing.assert_allclose(beta_val, scipy_beta, atol=1e-6, rtol=1e-6)

        # In-place check: only for unbatched inputs whose realness matches the requested output.
        if len(shape) == 2 and (output == "complex") == complex:
            A_f = np.asfortranarray(A_val.copy())
            B_f = np.asfortranarray(B_val.copy())
            f_mut = function(
                [In(A, mutable=True), In(B, mutable=True)],
                outputs,
                mode=get_default_mode().including("inplace"),
            )
            f_mut(A_f, B_f)
            # The mutable input buffers must now hold the Schur forms.
            np.testing.assert_allclose(A_f, scipy_AA, atol=1e-6, rtol=1e-6)
            np.testing.assert_allclose(B_f, scipy_BB, atol=1e-6, rtol=1e-6)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论