提交 a3bf6bb6 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add TRSYL Op, refactor linear control Ops

上级 afabe83f
......@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import (
LUFactor,
Solve,
SolveBase,
SolveBilinearDiscreteLyapunov,
SolveTriangular,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
......@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov])
@node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
"""
A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct")
......
......@@ -11,6 +11,7 @@ from scipy.linalg import get_lapack_funcs
import pytensor
from pytensor import ifelse
from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
......@@ -21,6 +22,7 @@ from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.reshape import join_dims
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
......@@ -1296,151 +1298,188 @@ class Expm(Op):
expm = Blockwise(Expm())
class SolveContinuousLyapunov(Op):
class TRSYL(Op):
"""
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X.
Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation:
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov
op(A) @ X + X @ op(B) = alpha * C
Where `op(A)` is either `A` or `A^T`, depending on the options passed to `trsyl`. A and B must be
in Schur canonical form: block upper triangular matrices with 1x1 and 2x2 blocks on the diagonal;
each 2x2 diagonal block has its diagonal elements equal and its off-diagonal elements opposite in sign.
This Op is not public facing. Instead, it is intended to be used as a building block for higher-level
linear control solvers, such as `SolveSylvester` and `SolveContinuousLyapunov`.
"""
__props__ = ()
gufunc_signature = "(m,m),(m,m)->(m,m)"
__props__ = ("overwrite_c",)
gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def __init__(self, overwrite_c=False):
self.overwrite_c = overwrite_c
if self.overwrite_c:
self.destroy_map = {0: [2]}
def make_node(self, A, B):
def make_node(self, A, B, C):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
C = as_tensor_variable(C)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, C.dtype)
return pytensor.graph.basic.Apply(self, [A, B], [X])
output_shape = list(C.type.shape)
if output_shape[0] is None and A.type.shape[0] is not None:
output_shape[0] = A.type.shape[0]
if output_shape[1] is None and B.type.shape[0] is not None:
output_shape[1] = B.type.shape[0]
def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]
X = tensor(dtype=out_dtype, shape=tuple(output_shape))
return Apply(self, [A, B, C], [X])
def perform(self, node, inputs, outputs_storage):
(A, B, C) = inputs
X = outputs_storage[0]
out_dtype = node.outputs[0].type.dtype
X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
(trsyl,) = get_lapack_funcs(("trsyl",), (A, B, C))
if A.size == 0 or B.size == 0:
return np.empty_like(C, dtype=out_dtype)
Y, scale, info = trsyl(A, B, C, overwrite_c=self.overwrite_c)
if info < 0:
return np.full_like(C, np.nan, dtype=out_dtype)
Y *= scale
X[0] = Y
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
return [shapes[2]]
def grad(self, inputs, output_grads):
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
# Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,
# so minor adjustments need to be made.
A, Q = inputs
(dX,) = output_grads
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs:
return self
new_props = self._props_dict() # type: ignore
new_props["overwrite_c"] = True
return type(self)(**new_props)
X = self(A, Q)
S = self(A.conj().T, -dX) # Eq 31, adjusted
A_bar = S.dot(X.conj().T) + S.conj().T.dot(X)
Q_bar = -S # Eq 29, adjusted
def _trsyl(A: TensorLike, B: TensorLike, C: TensorLike) -> TensorVariable:
A = as_tensor_variable(A)
B = as_tensor_variable(B)
C = as_tensor_variable(C)
return [A_bar, Q_bar]
return cast(TensorVariable, Blockwise(TRSYL())(A, B, C))
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
class SolveSylvester(OpFromGraph):
"""
Wrapper Op for solving the continuous Sylvester equation :math:`AX + XB = C` for :math:`X`.
"""
gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
def _lop_solve_continuous_sylvester(inputs, outputs, output_grads):
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Closed-form gradients for the solution for the Sylvester equation.
Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
Note that these authors write the equation as AP + PB + C = 0. The code here follows scipy notation,
so P = X and C = -Q. This change of notation requires minor adjustment to equations 10 and 11c
"""
A, B, _ = inputs
(dX,) = output_grads
(X,) = outputs
S = solve_sylvester(A.conj().mT, B.conj().mT, -dX) # Eq 10
A_bar = S @ X.conj().mT # Eq 11a
B_bar = X.conj().mT @ S # Eq 11b
Q_bar = -S # Eq 11c
return [A_bar, B_bar, Q_bar]
def solve_sylvester(A: TensorLike, B: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the Sylvester equation :math:`AX + XB = C` for :math:`X`.
Following scipy notation, this function solves the continuous-time Sylvester equation.
Parameters
----------
A: TensorLike
Square matrix of shape ``M x M``.
B: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
Matrix of shape ``M x N``.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
"""
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
docstring for scipy.linalg.solve_discrete_lyapunov
Matrix of shape ``M x N``.
"""
gufunc_signature = "(m,m),(m,m)->(m,m)"
def make_node(self, A, B):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
Q = as_tensor_variable(Q)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
A_matrix = pt.matrix(dtype=A.dtype, shape=A.type.shape[-2:])
B_matrix = pt.matrix(dtype=B.dtype, shape=B.type.shape[-2:])
Q_matrix = pt.matrix(dtype=Q.dtype, shape=Q.type.shape[-2:])
return pytensor.graph.basic.Apply(self, [A, B], [X])
R, U = schur(A_matrix, output="real")
S, V = schur(B_matrix, output="real")
F = U.conj().mT @ Q_matrix @ V
def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]
Y = _trsyl(R, S, F)
X = U @ Y @ V.conj().mT
out_dtype = node.outputs[0].type.dtype
X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
op = SolveSylvester(
inputs=[A_matrix, B_matrix, Q_matrix],
outputs=[X],
lop_overrides=_lop_solve_continuous_sylvester,
)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
def grad(self, inputs, output_grads):
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
A, Q = inputs
(dX,) = output_grads
X = self(A, Q)
# Eq 41, note that it is not written as a proper Lyapunov equation
S = self(A.conj().T, dX)
return cast(TensorVariable, Blockwise(op)(A, B, Q))
A_bar = pytensor.tensor.linalg.matrix_dot(
S, A, X.conj().T
) + pytensor.tensor.linalg.matrix_dot(S.conj().T, A, X)
Q_bar = S
return [A_bar, Q_bar]
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())
Note that the lyapunov equation is a special case of the Sylvester equation, with :math:`B = A^H`. This function
thus simply calls `solve_sylvester` with the appropriate arguments.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
def _direct_solve_discrete_lyapunov(
A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
r"""
Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the kronecker method of Magnus and
Neudecker.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
"""
A = as_tensor_variable(A)
Q = as_tensor_variable(Q)
if A.type.dtype.startswith("complex"):
AxA = kron(A, A.conj())
else:
AxA = kron(A, A)
return solve_sylvester(A, A.conj().mT, Q)
eye = pt.eye(AxA.shape[-1])
vec_Q = Q.ravel()
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
class SolveBilinearDiscreteLyapunov(OpFromGraph):
"""
Wrapper Op for solving the discrete Lyapunov equation :math:`A X A^H - X = Q` for :math:`X`.
return reshape(vec_X, A.shape)
Required so that backends that do not support method='bilinear' in `solve_discrete_lyapunov` can be rewritten
to method='direct'.
"""
def solve_discrete_lyapunov(
......@@ -1477,12 +1516,29 @@ def solve_discrete_lyapunov(
Q = as_tensor_variable(Q)
if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
return cast(TensorVariable, X)
vec_kron = pt.vectorize(kron, signature="(n,n),(n,n)->(m,m)")
AxA = vec_kron(A, A.conj())
eye = pt.eye(AxA.shape[-1])
vec_Q = join_dims(Q, start_axis=-2, n_axes=2)
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
return reshape(vec_X, A.shape)
elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))
I = pt.eye(A.shape[-2])
B_1 = A.conj().mT + I
B_2 = A.conj().mT - I
B = solve(B_1.mT, B_2.mT).mT
AI_inv_Q = solve(A + I, Q)
C = 2 * solve(B_1.mT, AI_inv_Q.mT).mT
X = solve_continuous_lyapunov(B.conj().mT, -C)
op = SolveBilinearDiscreteLyapunov(inputs=[A, Q], outputs=[X])
return cast(TensorVariable, op(A, Q))
else:
raise ValueError(f"Unknown method {method}")
......@@ -2270,5 +2326,6 @@ __all__ = [
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_discrete_lyapunov",
"solve_sylvester",
"solve_triangular",
]
......@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import (
solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov,
solve_sylvester,
solve_triangular,
)
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
......@@ -916,6 +917,67 @@ def test_expm_grad(mode):
utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
@pytest.mark.parametrize(
"shape, use_complex",
[((5, 5), False), ((5, 5), True), ((5, 5, 5), False)],
ids=["float", "complex", "batch_float"],
)
def test_solve_continuous_sylvester(shape: tuple[int], use_complex: bool):
# batch-complex case got an error from BatchedDot not implemented for complex numbers
rng = np.random.default_rng()
dtype = config.floatX
if use_complex:
dtype = "complex128" if dtype == "float64" else "complex64"
A1, A2 = rng.normal(size=(2, *shape))
B1, B2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
if use_complex:
A_val = A1 + 1j * A2
B_val = B1 + 1j * B2
Q_val = Q1 + 1j * Q2
else:
A_val = A1
B_val = B1
Q_val = Q1
A = pt.tensor("A", shape=shape, dtype=dtype)
B = pt.tensor("B", shape=shape, dtype=dtype)
Q = pt.tensor("Q", shape=shape, dtype=dtype)
X = solve_sylvester(A, B, Q)
Q_recovered = A @ X + X @ B
fn = function([A, B, Q], [X, Q_recovered])
X_val, Q_recovered_val = fn(A_val, B_val, Q_val)
vec_sylvester = np.vectorize(
scipy_linalg.solve_sylvester, signature="(m,m),(m,m),(m,m)->(m,m)"
)
np.testing.assert_allclose(Q_recovered_val, Q_val, atol=1e-8, rtol=1e-8)
np.testing.assert_allclose(
X_val, vec_sylvester(A_val, B_val, Q_val), atol=1e-8, rtol=1e-8
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_sylvester_grad(shape: tuple[int], use_complex):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
rng = np.random.default_rng(utt.fetch_seed())
A = rng.normal(size=shape).astype(config.floatX)
B = rng.normal(size=shape).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX)
utt.verify_grad(solve_sylvester, pt=[A, B, Q], rng=rng)
def recover_Q(A, X, continuous=True):
if continuous:
return A @ X + X @ A.conj().T
......@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient(
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
dtype = config.floatX
if use_complex and dtype == "float32":
pytest.skip(
"Not enough precision in complex64 to do schur decomposition "
"(ill-conditioned matrix errors arise)"
)
rng = np.random.default_rng(utt.fetch_seed())
if use_complex:
precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
if use_complex:
A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
a = pt.tensor(name="a", shape=shape, dtype=dtype)
q = pt.tensor(name="q", shape=shape, dtype=dtype)
x = solve_continuous_lyapunov(a, q)
f = function([a, q], x)
X = f(A, Q)
Q_recovered = vec_recover_Q(A, X, continuous=True)
atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
def test_solve_continuous_lyapunov():
# solve_continuous_lyapunov just calls solve_sylvester, so extensive tests are not needed.
A = pt.tensor("A", shape=(3, 5, 5))
Q = pt.tensor("Q", shape=(3, 5, 5))
X = solve_continuous_lyapunov(A, Q)
Q_recovered = A @ X + X @ A.conj().mT
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
fn = function([A, Q], [X, Q_recovered])
rng = np.random.default_rng(utt.fetch_seed())
A = rng.normal(size=shape).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX)
A_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
Q_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
_, Q_recovered_val = fn(A_val, Q_val)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)
utt.verify_grad(solve_continuous_lyapunov, pt=[A_val, Q_val], rng=rng)
@pytest.mark.parametrize("add_batch_dim", [False, True])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论