提交 55dad3b8 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Re-implement solve_discrete_are as OpFromGraph

上级 9f911e35
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
from scipy import linalg as scipy_linalg
from scipy.linalg import get_lapack_funcs from scipy.linalg import get_lapack_funcs
import pytensor import pytensor
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
import pytensor.tensor.math as ptm
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable, zeros
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.functional import vectorize from pytensor.tensor.functional import vectorize
from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.nlinalg import kron, matrix_dot, norm
from pytensor.tensor.reshape import join_dims from pytensor.tensor.reshape import join_dims
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.slinalg import schur, solve from pytensor.tensor.slinalg import lu, qr, qz, schur, solve, solve_triangular
from pytensor.tensor.type import matrix from pytensor.tensor.type import matrix
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -260,43 +260,16 @@ def solve_discrete_lyapunov( ...@@ -260,43 +260,16 @@ def solve_discrete_lyapunov(
raise ValueError(f"Unknown method {method}") raise ValueError(f"Unknown method {method}")
class SolveDiscreteARE(Op): def _lop_solve_discrete_are(inputs, outputs, output_grads):
__props__ = ("enforce_Q_symmetric",) """
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)" Closed-form gradients for the solution for the discrete Algebraic Riccati equation.
def __init__(self, enforce_Q_symmetric: bool = False):
self.enforce_Q_symmetric = enforce_Q_symmetric
def make_node(self, A, B, Q, R):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
Q = as_tensor_variable(Q)
R = as_tensor_variable(R)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])
def perform(self, node, inputs, output_storage):
A, B, Q, R = inputs
X = output_storage[0]
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
def grad(self, inputs, output_grads): Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf """
A, B, Q, R = inputs A, B, Q, R = inputs
(dX,) = output_grads (dX,) = output_grads
X = self(A, B, Q, R) X = solve_discrete_are(A, B, Q, R)
K_inner = R + matrix_dot(B.T, X, B) K_inner = R + matrix_dot(B.T, X, B)
...@@ -317,12 +290,20 @@ class SolveDiscreteARE(Op): ...@@ -317,12 +290,20 @@ class SolveDiscreteARE(Op):
return [A_bar, B_bar, Q_bar, R_bar] return [A_bar, B_bar, Q_bar, R_bar]
class SolveDiscreteARE(OpFromGraph):
"""
Wrapper Op for solving the discrete Algebraic Riccati equation
:math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0` for :math:`X`.
"""
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
def solve_discrete_are( def solve_discrete_are(
A: TensorLike, A: TensorLike,
B: TensorLike, B: TensorLike,
Q: TensorLike, Q: TensorLike,
R: TensorLike, R: TensorLike,
enforce_Q_symmetric: bool = False,
) -> TensorVariable: ) -> TensorVariable:
""" """
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`. Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
...@@ -344,18 +325,128 @@ def solve_discrete_are( ...@@ -344,18 +325,128 @@ def solve_discrete_are(
Symmetric square matrix of shape M x M Symmetric square matrix of shape M x M
R: TensorLike R: TensorLike
Square matrix of shape N x N Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Returns Returns
------- -------
X: TensorVariable X: TensorVariable
Square matrix of shape M x M, representing the solution to the DARE Square matrix of shape M x M, representing the solution to the DARE
Notes
-----
This function is copied from the scipy implementation, found here: https://github.com/scipy/scipy/blob/892baa06054c31bed734423c0f53eaed52b1914b/scipy/linalg/_solvers.py#L687
Notes are also adapted from the scipy documentation.
The equation is solved by forming the extended symplectic matrix pencil as described in [1]_,
:math:`H - \\lambda J`, given by the block matrices:
.. math::
H = \\begin{bmatrix} A & 0 & B \\\\
-Q & I & 0 \\\\
0 & 0 & R \\end{bmatrix}
, \\quad
J = \\begin{bmatrix} I & 0 & 0 \\\\
0 & A^H & 0 \\\\
0 & -B^H & 0 \\end{bmatrix}
The stable invariant subspace of the pencil is then computed via the QZ decomposition. Failure conditions are
linked to the symmetry of the solution matrix :math:`U_2 U_1^{-1}`, as described in [1]_ and [2]_. When the
solution is not symmetric, NaNs are returned.
[3]_ describes a balancing procedure for Hamiltonian matrices that can improve numerical stability. This procedure
is not yet implemented in this function.
References
----------
.. [1] P. van Dooren, "A Generalized Eigenvalue Approach For Solving
Riccati Equations.", SIAM Journal on Scientific and Statistical
Computing, Vol.2(2), :doi:`10.1137/0902010`
.. [2] A.J. Laub, "A Schur Method for Solving Algebraic Riccati
Equations.", Massachusetts Institute of Technology. Laboratory for
Information and Decision Systems. LIDS-R ; 859. Available online :
http://hdl.handle.net/1721.1/1301
.. [3] P. Benner, "Symplectic Balancing of Hamiltonian Matrices",
SIAM J. Sci. Comput., 2001, Vol.22(5), :doi:`10.1137/S1064827500367993`
""" """
A, B, Q, R = map(as_tensor_variable, (A, B, Q, R))
is_complex = any(
input_matrix.type.numpy_dtype.kind == "c" for input_matrix in (A, B, Q, R)
)
A_core = matrix(dtype=A.dtype, shape=A.type.shape[-2:])
B_core = matrix(dtype=B.dtype, shape=B.type.shape[-2:])
Q_core = matrix(dtype=Q.dtype, shape=Q.type.shape[-2:])
R_core = matrix(dtype=R.dtype, shape=R.type.shape[-2:])
# Given Zmm = zeros(m, m), Zmn = zeros(m, n), Znm = zeros(n, m), E = eye(m)
# Construct the block matrix H of shape (2m + n, 2m + n):
# H = block([[ A, Zmm, B ],
# [ -Q, E, Zmn],
# [Znm, Znm, R ]])
m, n = B_core.shape[-2:]
H = zeros((2 * m + n, 2 * m + n), dtype=A.dtype)
H = H[:m, :m].set(A_core)
H = H[:m, 2 * m :].set(B_core)
H = H[m : 2 * m, :m].set(-Q_core)
H = H[m : 2 * m, m : 2 * m].set(ptb.eye(m))
H = H[2 * m :, 2 * m :].set(R_core)
# Construct block matrix J of shape (2m + n, 2m + n):
# J = block([[ E, Zmm, Zmn],
# [Zmm, A^H, Zmn],
# [Znm, -B^H, Zmn]])
J = zeros((2 * m + n, 2 * m + n), dtype=A_core.dtype)
J = J[:m, :m].set(ptb.eye(m))
J = J[m : 2 * m, m : 2 * m].set(A_core.conj().T)
J = J[2 * m :, m : 2 * m].set(-B_core.conj().T)
# TODO: Implement balancing procedure from [3]_
Q_of_QR, _ = qr(H[:, -n:], mode="full")
H = Q_of_QR[:, n:].conj().T @ H[:, : 2 * m]
J = Q_of_QR[:, n:].conj().T @ J[:, : 2 * m]
*_, U = qz(
H,
J,
sort="iuc",
output="complex" if is_complex else "real",
return_eigenvalues=False,
)
return cast( U00 = U[:m, :m]
TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R) U10 = U[m:, :m]
UP, UL, UU = lu(U00) # type: ignore[misc]
lhs = solve_triangular(
UL.conj().T,
solve_triangular(UU.conj().T, U10.conj().T, lower=True),
unit_diagonal=True,
) )
X = lhs.conj().T @ UP.conj().T
U_sym = U00.conj().T @ U10
norm_U_sym = norm(U_sym, ord=1)
U_sym = U_sym - U_sym.conj().T
sym_threshold = ptm.maximum(np.spacing(1000.0), 0.1 * norm_U_sym)
result = ptb.switch(
norm(U_sym, ord=1) > sym_threshold,
ptb.full_like(X, np.nan),
0.5 * (X + X.conj().T),
)
core_op = SolveDiscreteARE(
inputs=[A_core, B_core, Q_core, R_core],
outputs=[result],
lop_overrides=_lop_solve_discrete_are,
)
return cast(TensorVariable, Blockwise(core_op)(A, B, Q, R))
__all__ = [ __all__ = [
......
...@@ -1118,7 +1118,7 @@ def test_solve_discrete_are_grad(add_batch_dim): ...@@ -1118,7 +1118,7 @@ def test_solve_discrete_are_grad(add_batch_dim):
atol = 1e-4 if config.floatX == "float32" else 1e-12 atol = 1e-4 if config.floatX == "float32" else 1e-12
utt.verify_grad( utt.verify_grad(
functools.partial(solve_discrete_are, enforce_Q_symmetric=True), solve_discrete_are,
pt=[a, b, q, r], pt=[a, b, q, r],
rng=rng, rng=rng,
abs_tol=atol, abs_tol=atol,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论