提交 a3bf6bb6 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add TRSYL Op, refactor linear control Ops

上级 afabe83f
...@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import ( ...@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import (
LUFactor, LUFactor,
Solve, Solve,
SolveBase, SolveBase,
SolveBilinearDiscreteLyapunov,
SolveTriangular, SolveTriangular,
_bilinear_solve_discrete_lyapunov,
block_diag, block_diag,
cholesky, cholesky,
solve, solve,
...@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): ...@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)] return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov]) @node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
""" """
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
""" """
A, B = (cast(TensorVariable, x) for x in node.inputs) A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct") result = solve_discrete_lyapunov(A, B, method="direct")
......
...@@ -11,6 +11,7 @@ from scipy.linalg import get_lapack_funcs ...@@ -11,6 +11,7 @@ from scipy.linalg import get_lapack_funcs
import pytensor import pytensor
from pytensor import ifelse from pytensor import ifelse
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, disconnected_type from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -21,6 +22,7 @@ from pytensor.tensor import math as ptm ...@@ -21,6 +22,7 @@ from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.reshape import join_dims
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -1296,151 +1298,188 @@ class Expm(Op): ...@@ -1296,151 +1298,188 @@ class Expm(Op):
expm = Blockwise(Expm()) expm = Blockwise(Expm())
class TRSYL(Op):
    """
    Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation:

        op(A) @ X + X @ op(B) = alpha * C

    Where `op(A)` is either `A` or `A^T`, depending on the options passed to `trsyl`. A and B must be
    in Schur canonical form: block upper triangular matrices with 1x1 and 2x2 blocks on the diagonal;
    each 2x2 diagonal block has its diagonal elements equal and its off-diagonal elements opposite in sign.

    This Op is not public facing. Instead, it is intended to be used as a building block for higher-level
    linear control solvers, such as `SolveSylvester` and `SolveContinuousLyapunov`.

    Parameters
    ----------
    overwrite_c : bool, default False
        Forwarded to `trsyl` as ``overwrite_c``. When True, the Op also declares (through ``destroy_map``)
        that output 0 may destroy input 2 (``C``).
    """

    __props__ = ("overwrite_c",)
    gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"

    def __init__(self, overwrite_c=False):
        self.overwrite_c = overwrite_c
        if self.overwrite_c:
            # Output 0 is allowed to reuse the memory of input 2 (C).
            self.destroy_map = {0: [2]}

    def make_node(self, A, B, C):
        """Build the Apply node; the output dtype is the upcast of all three input dtypes."""
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        C = as_tensor_variable(C)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, C.dtype)

        # Start from the static shape of C and fill unknown dimensions from A (rows) and B (columns).
        output_shape = list(C.type.shape)
        if output_shape[0] is None and A.type.shape[0] is not None:
            output_shape[0] = A.type.shape[0]
        if output_shape[1] is None and B.type.shape[0] is not None:
            output_shape[1] = B.type.shape[0]

        X = tensor(dtype=out_dtype, shape=tuple(output_shape))

        return Apply(self, [A, B, C], [X])

    def perform(self, node, inputs, outputs_storage):
        """Run `trsyl` on concrete arrays and store the result in ``outputs_storage[0][0]``."""
        (A, B, C) = inputs
        X = outputs_storage[0]

        out_dtype = node.outputs[0].type.dtype
        (trsyl,) = get_lapack_funcs(("trsyl",), (A, B, C))

        if A.size == 0 or B.size == 0:
            # Nothing to solve. The result must be written to the output storage:
            # the return value of `perform` is ignored.
            X[0] = np.empty_like(C, dtype=out_dtype)
            return

        Y, scale, info = trsyl(A, B, C, overwrite_c=self.overwrite_c)

        if info < 0:
            # LAPACK reported a failure; signal it with a NaN-filled output instead of raising.
            X[0] = np.full_like(C, np.nan, dtype=out_dtype)
            return

        # Apply the scale factor returned by `trsyl`.
        Y *= scale
        X[0] = Y

    def infer_shape(self, fgraph, node, shapes):
        """The output has the same shape as the third input, ``C``."""
        return [shapes[2]]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a variant of this Op that overwrites ``C``, if input 2 may be destroyed."""
        # Only C (input index 2) is ever overwritten, so permission on A or B alone is not enough.
        if 2 not in allowed_inplace_inputs:
            return self
        new_props = self._props_dict()  # type: ignore
        new_props["overwrite_c"] = True
        return type(self)(**new_props)
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
def _trsyl(A: TensorLike, B: TensorLike, C: TensorLike) -> TensorVariable:
    """Convert ``A``, ``B`` and ``C`` to tensor variables and apply a `Blockwise`-wrapped `TRSYL` Op to them."""
    operands = [as_tensor_variable(x) for x in (A, B, C)]
    batched_trsyl = Blockwise(TRSYL())

    return cast(TensorVariable, batched_trsyl(*operands))
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
class SolveSylvester(OpFromGraph):
    """
    `OpFromGraph` wrapper around the graph that solves the continuous Sylvester equation
    :math:`AX + XB = C` for :math:`X`.
    """

    # Core signature: square A (m x m), square B (n x n), right-hand side and solution (m x n).
    gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def _lop_solve_continuous_sylvester(inputs, outputs, output_grads):
    """
    Closed-form gradients of the Sylvester equation solution.

    The formulas come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf.
    Those authors write the equation as AP + PB + C = 0, whereas this code follows scipy notation,
    so P = X and C = -Q. Equations 10 and 11c are adjusted accordingly.

    Parameters
    ----------
    inputs
        The three inputs ``(A, B, Q)``; the third one is not needed by the formulas.
    outputs
        One-element sequence holding the solution ``X``.
    output_grads
        One-element sequence holding the gradient with respect to ``X``.

    Returns
    -------
    list
        Gradients with respect to ``A``, ``B`` and ``Q``, in that order.
    """
    A, B, _Q = inputs
    (X,) = outputs
    (X_bar,) = output_grads

    # Eq 10: the adjoint variable solves another Sylvester equation with transposed-conjugated coefficients.
    adjoint = solve_sylvester(A.conj().mT, B.conj().mT, -X_bar)
    X_H = X.conj().mT

    # Eq 11a, 11b and 11c, respectively.
    return [adjoint @ X_H, X_H @ adjoint, -adjoint]
A = as_tensor_variable(A)
B = as_tensor_variable(B)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
def solve_sylvester(A: TensorLike, B: TensorLike, Q: TensorLike) -> TensorVariable:
    """
    Solve the Sylvester equation :math:`AX + XB = Q` for :math:`X`.

    Following scipy notation, this function solves the continuous-time Sylvester equation.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape ``M x M``.
    B: TensorLike
        Square matrix of shape ``N x N``.
    Q: TensorLike
        Matrix of shape ``M x N``.

    Returns
    -------
    X: TensorVariable
        Matrix of shape ``M x N``.
    """
    A, B, Q = (as_tensor_variable(x) for x in (A, B, Q))

    # Matrix placeholders carrying the dtype and trailing two static dims of each input;
    # the inner graph is written for a single matrix problem and wrapped in Blockwise below.
    core_A, core_B, core_Q = (
        pt.matrix(dtype=x.dtype, shape=x.type.shape[-2:]) for x in (A, B, Q)
    )

    # Reduce A and B with `schur`, as `_trsyl` requires coefficients in Schur form.
    R, U = schur(core_A, output="real")
    S, V = schur(core_B, output="real")

    # Transform the right-hand side into the Schur bases, solve, and transform back.
    rhs = U.conj().mT @ core_Q @ V
    Y = _trsyl(R, S, rhs)
    core_X = U @ Y @ V.conj().mT

    sylvester_op = SolveSylvester(
        inputs=[core_A, core_B, core_Q],
        outputs=[core_X],
        lop_overrides=_lop_solve_continuous_sylvester,
    )

    return cast(TensorVariable, Blockwise(sylvester_op)(A, B, Q))
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
    """
    Solve the continuous Lyapunov equation :math:`A X + X A^H = Q` for :math:`X`.

    Note that the Lyapunov equation is a special case of the Sylvester equation :math:`AX + XB = Q`, with
    :math:`B = A^H`. This function thus simply calls `solve_sylvester` with the appropriate arguments.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape ``N x N``.
    Q: TensorLike
        Square matrix of shape ``N x N``.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``
    """
    A = as_tensor_variable(A)
    Q = as_tensor_variable(Q)

    return solve_sylvester(A, A.conj().mT, Q)
AxA = kron(A, A.conj())
else:
AxA = kron(A, A)
eye = pt.eye(AxA.shape[-1])
class SolveBilinearDiscreteLyapunov(OpFromGraph):
    """
    `OpFromGraph` wrapper for solving the discrete Lyapunov equation :math:`A X A^H - X = Q` for :math:`X`.

    The dedicated Op type exists so that backends without support for method='bilinear' in
    `solve_discrete_lyapunov` can have it rewritten to method='direct'.
    """
def solve_discrete_lyapunov( def solve_discrete_lyapunov(
...@@ -1477,12 +1516,29 @@ def solve_discrete_lyapunov( ...@@ -1477,12 +1516,29 @@ def solve_discrete_lyapunov(
Q = as_tensor_variable(Q) Q = as_tensor_variable(Q)
if method == "direct": if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature vec_kron = pt.vectorize(kron, signature="(n,n),(n,n)->(m,m)")
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q) AxA = vec_kron(A, A.conj())
return cast(TensorVariable, X) eye = pt.eye(AxA.shape[-1])
vec_Q = join_dims(Q, start_axis=-2, n_axes=2)
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
return reshape(vec_X, A.shape)
elif method == "bilinear": elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q)) I = pt.eye(A.shape[-2])
B_1 = A.conj().mT + I
B_2 = A.conj().mT - I
B = solve(B_1.mT, B_2.mT).mT
AI_inv_Q = solve(A + I, Q)
C = 2 * solve(B_1.mT, AI_inv_Q.mT).mT
X = solve_continuous_lyapunov(B.conj().mT, -C)
op = SolveBilinearDiscreteLyapunov(inputs=[A, Q], outputs=[X])
return cast(TensorVariable, op(A, Q))
else: else:
raise ValueError(f"Unknown method {method}") raise ValueError(f"Unknown method {method}")
...@@ -2270,5 +2326,6 @@ __all__ = [ ...@@ -2270,5 +2326,6 @@ __all__ = [
"solve_continuous_lyapunov", "solve_continuous_lyapunov",
"solve_discrete_are", "solve_discrete_are",
"solve_discrete_lyapunov", "solve_discrete_lyapunov",
"solve_sylvester",
"solve_triangular", "solve_triangular",
] ]
...@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import ( ...@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import (
solve_continuous_lyapunov, solve_continuous_lyapunov,
solve_discrete_are, solve_discrete_are,
solve_discrete_lyapunov, solve_discrete_lyapunov,
solve_sylvester,
solve_triangular, solve_triangular,
) )
from pytensor.tensor.type import dmatrix, matrix, tensor, vector from pytensor.tensor.type import dmatrix, matrix, tensor, vector
...@@ -916,6 +917,67 @@ def test_expm_grad(mode): ...@@ -916,6 +917,67 @@ def test_expm_grad(mode):
utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5) utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
@pytest.mark.parametrize(
    "shape, use_complex",
    [((5, 5), False), ((5, 5), True), ((5, 5, 5), False)],
    ids=["float", "complex", "batch_float"],
)
def test_solve_continuous_sylvester(shape: tuple[int, ...], use_complex: bool):
    """Check `solve_sylvester` against scipy and against the defining equation ``A @ X + X @ B == Q``."""
    # batch-complex case got an error from BatchedDot not implemented for complex numbers

    # Seed the generator like the other tests in this file so failures are reproducible.
    rng = np.random.default_rng(utt.fetch_seed())

    dtype = config.floatX
    if use_complex:
        dtype = "complex128" if dtype == "float64" else "complex64"

    A1, A2 = rng.normal(size=(2, *shape))
    B1, B2 = rng.normal(size=(2, *shape))
    Q1, Q2 = rng.normal(size=(2, *shape))

    if use_complex:
        A_val = A1 + 1j * A2
        B_val = B1 + 1j * B2
        Q_val = Q1 + 1j * Q2
    else:
        A_val = A1
        B_val = B1
        Q_val = Q1

    # The generator yields float64; cast so the test values match the symbolic inputs' dtype.
    A_val, B_val, Q_val = (x.astype(dtype) for x in (A_val, B_val, Q_val))

    A = pt.tensor("A", shape=shape, dtype=dtype)
    B = pt.tensor("B", shape=shape, dtype=dtype)
    Q = pt.tensor("Q", shape=shape, dtype=dtype)

    X = solve_sylvester(A, B, Q)
    Q_recovered = A @ X + X @ B

    fn = function([A, B, Q], [X, Q_recovered])
    X_val, Q_recovered_val = fn(A_val, B_val, Q_val)

    # scipy's solver only handles one matrix problem at a time; vectorize it over batch dimensions.
    vec_sylvester = np.vectorize(
        scipy_linalg.solve_sylvester, signature="(m,m),(m,m),(m,m)->(m,m)"
    )

    # Single precision cannot reach 1e-8; use the same float32-aware tolerance as the Lyapunov test.
    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
    np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)
    np.testing.assert_allclose(
        X_val, vec_sylvester(A_val, B_val, Q_val), atol=atol, rtol=rtol
    )
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_sylvester_grad(shape: tuple[int], use_complex):
    """Numerically verify the gradient of `solve_sylvester` for real-valued inputs."""
    if config.floatX == "float32":
        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
    if use_complex:
        pytest.skip(reason="Complex numbers are not supported in the gradient test")

    rng = np.random.default_rng(utt.fetch_seed())

    # Three independent draws, in order: A, B and Q.
    test_points = [rng.normal(size=shape).astype(config.floatX) for _ in range(3)]

    utt.verify_grad(solve_sylvester, pt=test_points, rng=rng)
def recover_Q(A, X, continuous=True): def recover_Q(A, X, continuous=True):
if continuous: if continuous:
return A @ X + X @ A.conj().T return A @ X + X @ A.conj().T
...@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient( ...@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient(
) )
def test_solve_continuous_lyapunov():
    """Smoke-test `solve_continuous_lyapunov` on a batch: recover ``Q`` from the solution, then check the gradient."""
    # solve_continuous_lyapunov just calls solve_sylvester, so extensive tests are not needed.
    batch_shape = (3, 5, 5)

    A = pt.tensor("A", shape=batch_shape)
    Q = pt.tensor("Q", shape=batch_shape)
    X = solve_continuous_lyapunov(A, Q)
    # Plugging the solution back into the equation should reproduce Q.
    Q_recovered = A @ X + X @ A.conj().mT

    fn = function([A, Q], [X, Q_recovered])

    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.normal(size=batch_shape).astype(config.floatX)
    Q_val = rng.normal(size=batch_shape).astype(config.floatX)
    Q_recovered_val = fn(A_val, Q_val)[1]

    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
    np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)

    utt.verify_grad(solve_continuous_lyapunov, pt=[A_val, Q_val], rng=rng)
@pytest.mark.parametrize("add_batch_dim", [False, True]) @pytest.mark.parametrize("add_batch_dim", [False, True])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论