Commit 55dad3b8, authored by jessegrabowski, committed by Jesse Grabowski

Re-implement solve_discrete_are as OpFromGraph

Parent 9f911e35
from typing import Literal, cast
import numpy as np
from scipy import linalg as scipy_linalg
from scipy.linalg import get_lapack_funcs
import pytensor
import pytensor.tensor.basic as ptb
import pytensor.tensor.math as ptm
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, Op
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.basic import as_tensor_variable, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.functional import vectorize
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.nlinalg import kron, matrix_dot, norm
from pytensor.tensor.reshape import join_dims
from pytensor.tensor.shape import reshape
from pytensor.tensor.slinalg import schur, solve
from pytensor.tensor.slinalg import lu, qr, qz, schur, solve, solve_triangular
from pytensor.tensor.type import matrix
from pytensor.tensor.variable import TensorVariable
......@@ -260,61 +260,43 @@ def solve_discrete_lyapunov(
raise ValueError(f"Unknown method {method}")
def _lop_solve_discrete_are(inputs, outputs, output_grads):
    """
    Closed-form gradients for the solution for the discrete Algebraic Riccati equation.

    Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
    """
    A, B, Q, R = inputs
    (dX,) = output_grads

    # Recompute the primal solution symbolically so the gradient graph is
    # self-contained.
    X = solve_discrete_are(A, B, Q, R)

    K_inner = R + matrix_dot(B.T, X, B)

    # K_inner is guaranteed to be symmetric, because X and R are symmetric
    K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
    K = matrix_dot(K_inner_inv_BT, X, A)

    # Closed-loop system matrix
    A_tilde = A - B.dot(K)

    # Symmetrize the incoming cotangent before solving the Lyapunov equation
    dX_symm = 0.5 * (dX + dX.T)
    S = solve_discrete_lyapunov(A_tilde, dX_symm)

    A_bar = 2 * matrix_dot(X, A_tilde, S)
    B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
    Q_bar = S
    R_bar = matrix_dot(K, S, K.T)

    return [A_bar, B_bar, Q_bar, R_bar]


class SolveDiscreteARE(OpFromGraph):
    """
    Wrapper Op for solving the discrete Algebraic Riccati equation
    :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0` for :math:`X`.
    """

    gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
def solve_discrete_are(
    A: TensorLike,
    B: TensorLike,
    Q: TensorLike,
    R: TensorLike,
) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape M x M
    B: TensorLike
        Matrix of shape M x N
    Q: TensorLike
        Symmetric square matrix of shape M x M
    R: TensorLike
        Square matrix of shape N x N

    Returns
    -------
    X: TensorVariable
        Square matrix of shape M x M, representing the solution to the DARE

    Notes
    -----
    This function is copied from the scipy implementation, found here:
    https://github.com/scipy/scipy/blob/892baa06054c31bed734423c0f53eaed52b1914b/scipy/linalg/_solvers.py#L687
    Notes are also adapted from the scipy documentation.

    The equation is solved by forming the extended symplectic matrix pencil as described in [1]_,
    :math:`H - \\lambda J`, given by the block matrices:

    .. math::
        H = \\begin{bmatrix} A & 0 & B \\\\
                            -Q & I & 0 \\\\
                             0 & 0 & R \\end{bmatrix}
        , \\quad
        J = \\begin{bmatrix} I & 0 & 0 \\\\
                             0 & A^H & 0 \\\\
                             0 & -B^H & 0 \\end{bmatrix}

    The stable invariant subspace of the pencil is then computed via the QZ decomposition. Failure conditions are
    linked to the symmetry of the solution matrix :math:`U_2 U_1^{-1}`, as described in [1]_ and [2]_. When the
    solution is not symmetric, NaNs are returned.

    [3]_ describes a balancing procedure for Hamiltonian matrices that can improve numerical stability. This procedure
    is not yet implemented in this function.

    References
    ----------
    .. [1] P. van Dooren , "A Generalized Eigenvalue Approach For Solving
           Riccati Equations.", SIAM Journal on Scientific and Statistical
           Computing, Vol.2(2), :doi:`10.1137/0902010`
    .. [2] A.J. Laub, "A Schur Method for Solving Algebraic Riccati
           Equations.", Massachusetts Institute of Technology. Laboratory for
           Information and Decision Systems. LIDS-R ; 859. Available online :
           http://hdl.handle.net/1721.1/1301
    .. [3] P. Benner, "Symplectic Balancing of Hamiltonian Matrices", 2001,
           SIAM J. Sci. Comput., 2001, Vol.22(5), :doi:`10.1137/S1064827500367993`
    """
    A, B, Q, R = map(as_tensor_variable, (A, B, Q, R))

    # Complex inputs force the complex QZ decomposition below.
    is_complex = any(
        input_matrix.type.numpy_dtype.kind == "c" for input_matrix in (A, B, Q, R)
    )

    # Core (non-batched) symbolic inputs for the OpFromGraph; Blockwise handles
    # any leading batch dimensions at the call site.
    A_core = matrix(dtype=A.dtype, shape=A.type.shape[-2:])
    B_core = matrix(dtype=B.dtype, shape=B.type.shape[-2:])
    Q_core = matrix(dtype=Q.dtype, shape=Q.type.shape[-2:])
    R_core = matrix(dtype=R.dtype, shape=R.type.shape[-2:])

    # Given Zmm = zeros(m, m), Zmn = zeros(m, n), Znm = zeros(n, m), E = eye(m),
    # construct the block matrix H of shape (2m + n, 2m + n):
    # H = block([[  A, Zmm,   B],
    #            [ -Q,   E, Zmn],
    #            [Znm, Znm,   R]])
    m, n = B_core.shape[-2:]

    # NOTE(review): H/J are allocated with A's dtype; presumably the caller is
    # expected to supply inputs of a common dtype — confirm upcasting behavior.
    H = zeros((2 * m + n, 2 * m + n), dtype=A.dtype)
    H = H[:m, :m].set(A_core)
    H = H[:m, 2 * m :].set(B_core)
    H = H[m : 2 * m, :m].set(-Q_core)
    H = H[m : 2 * m, m : 2 * m].set(ptb.eye(m))
    H = H[2 * m :, 2 * m :].set(R_core)

    # Construct block matrix J of shape (2m + n, 2m + n):
    # J = block([[  E,  Zmm, Zmn],
    #            [Zmm,  A^H, Zmn],
    #            [Znm, -B^H, Zmn]])
    J = zeros((2 * m + n, 2 * m + n), dtype=A_core.dtype)
    J = J[:m, :m].set(ptb.eye(m))
    J = J[m : 2 * m, m : 2 * m].set(A_core.conj().T)
    J = J[2 * m :, m : 2 * m].set(-B_core.conj().T)

    # TODO: Implement balancing procedure from [3]_

    # Compress the (2m + n) pencil to size 2m by projecting onto the orthogonal
    # complement of the column space of H[:, -n:].
    Q_of_QR, _ = qr(H[:, -n:], mode="full")
    H = Q_of_QR[:, n:].conj().T @ H[:, : 2 * m]
    J = Q_of_QR[:, n:].conj().T @ J[:, : 2 * m]

    # QZ with eigenvalues inside the unit circle ("iuc") sorted first puts the
    # stable invariant subspace in the leading m columns of U.
    *_, U = qz(
        H,
        J,
        sort="iuc",
        output="complex" if is_complex else "real",
        return_eigenvalues=False,
    )

    U00 = U[:m, :m]
    U10 = U[m:, :m]

    # X = U10 @ inv(U00), computed via an LU factorization of U00 (triangular
    # solves against the conjugate-transposed factors) instead of an explicit
    # inverse.
    UP, UL, UU = lu(U00)  # type: ignore[misc]
    lhs = solve_triangular(
        UL.conj().T,
        solve_triangular(UU.conj().T, U10.conj().T, lower=True),
        unit_diagonal=True,
    )
    X = lhs.conj().T @ UP.conj().T

    # Failure check: the candidate U00^H U10 must be (nearly) Hermitian. When
    # the anti-symmetric part exceeds the threshold, return NaNs.
    U_sym = U00.conj().T @ U10
    norm_U_sym = norm(U_sym, ord=1)
    U_sym = U_sym - U_sym.conj().T
    sym_threshold = ptm.maximum(np.spacing(1000.0), 0.1 * norm_U_sym)

    result = ptb.switch(
        norm(U_sym, ord=1) > sym_threshold,
        ptb.full_like(X, np.nan),
        0.5 * (X + X.conj().T),
    )

    core_op = SolveDiscreteARE(
        inputs=[A_core, B_core, Q_core, R_core],
        outputs=[result],
        lop_overrides=_lop_solve_discrete_are,
    )

    return cast(TensorVariable, Blockwise(core_op)(A, B, Q, R))
__all__ = [
"solve_continuous_lyapunov",
......
......@@ -1118,7 +1118,7 @@ def test_solve_discrete_are_grad(add_batch_dim):
atol = 1e-4 if config.floatX == "float32" else 1e-12
utt.verify_grad(
functools.partial(solve_discrete_are, enforce_Q_symmetric=True),
solve_discrete_are,
pt=[a, b, q, r],
rng=rng,
abs_tol=atol,
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment