提交 a3bf6bb6 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Add TRSYL Op, refactor linear control Ops

上级 afabe83f
...@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import ( ...@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import (
LUFactor, LUFactor,
Solve, Solve,
SolveBase, SolveBase,
SolveBilinearDiscreteLyapunov,
SolveTriangular, SolveTriangular,
_bilinear_solve_discrete_lyapunov,
block_diag, block_diag,
cholesky, cholesky,
solve, solve,
...@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): ...@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)] return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov]) @node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
""" """
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
""" """
A, B = (cast(TensorVariable, x) for x in node.inputs) A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct") result = solve_discrete_lyapunov(A, B, method="direct")
......
...@@ -11,6 +11,7 @@ from scipy.linalg import get_lapack_funcs ...@@ -11,6 +11,7 @@ from scipy.linalg import get_lapack_funcs
import pytensor import pytensor
from pytensor import ifelse from pytensor import ifelse
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, disconnected_type from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -21,6 +22,7 @@ from pytensor.tensor import math as ptm ...@@ -21,6 +22,7 @@ from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.reshape import join_dims
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -1296,151 +1298,188 @@ class Expm(Op): ...@@ -1296,151 +1298,188 @@ class Expm(Op):
expm = Blockwise(Expm()) expm = Blockwise(Expm())
class SolveContinuousLyapunov(Op): class TRSYL(Op):
""" """
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X. Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation:
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved op(A) @ X + X @ op(B) = alpha * C
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov Where `op(A)` is either `A` or `A^T`, depending on the options passed to `trsyl`. A and B must be
in Schur canonical form: block upper triangular matrices with 1x1 and 2x2 blocks on the diagonal;
each 2x2 diagonal block has its diagonal elements equal and its off-diagonal elements opposite in sign.
This Op is not public facing. Instead, it is intended to be used as a building block for higher-level
linear control solvers, such as `SolveSylvester` and `SolveContinuousLyapunov`.
""" """
__props__ = () __props__ = ("overwrite_c",)
gufunc_signature = "(m,m),(m,m)->(m,m)" gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def __init__(self, overwrite_c=False):
self.overwrite_c = overwrite_c
if self.overwrite_c:
self.destroy_map = {0: [2]}
def make_node(self, A, B): def make_node(self, A, B, C):
A = as_tensor_variable(A) A = as_tensor_variable(A)
B = as_tensor_variable(B) B = as_tensor_variable(B)
C = as_tensor_variable(C)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype) out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, C.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
return pytensor.graph.basic.Apply(self, [A, B], [X]) output_shape = list(C.type.shape)
if output_shape[0] is None and A.type.shape[0] is not None:
output_shape[0] = A.type.shape[0]
if output_shape[1] is None and B.type.shape[0] is not None:
output_shape[1] = B.type.shape[0]
def perform(self, node, inputs, output_storage): X = tensor(dtype=out_dtype, shape=tuple(output_shape))
(A, B) = inputs
X = output_storage[0] return Apply(self, [A, B, C], [X])
def perform(self, node, inputs, outputs_storage):
    """
    Numerically solve the triangular Sylvester system with LAPACK's ``trsyl``.

    Parameters
    ----------
    node
        The Apply node being evaluated. Only the dtype of its first output is read.
    inputs
        The numeric values ``(A, B, C)``.
    outputs_storage
        Output storage cells. The result is written to ``outputs_storage[0][0]``.

    Notes
    -----
    Every exit path stores its result in the output cell. The early-exit
    branches previously returned their arrays instead of storing them, which
    left the output cell unset for empty inputs and for illegal-argument
    failures reported by LAPACK.
    """
    (A, B, C) = inputs
    X = outputs_storage[0]
    out_dtype = node.outputs[0].type.dtype

    # Nothing to solve: hand back an (empty) array shaped like C.
    if A.size == 0 or B.size == 0:
        X[0] = np.empty_like(C, dtype=out_dtype)
        return

    (trsyl,) = get_lapack_funcs(("trsyl",), (A, B, C))
    Y, scale, info = trsyl(A, B, C, overwrite_c=self.overwrite_c)

    # A negative info flags a failed call; signal it with a NaN-filled result.
    if info < 0:
        X[0] = np.full_like(C, np.nan, dtype=out_dtype)
        return

    # Apply the scale factor that LAPACK returns alongside the solution.
    Y *= scale
    X[0] = Y
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[2]]
def grad(self, inputs, output_grads): def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf if not allowed_inplace_inputs:
# Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q, return self
# so minor adjustments need to be made. new_props = self._props_dict() # type: ignore
A, Q = inputs new_props["overwrite_c"] = True
(dX,) = output_grads return type(self)(**new_props)
X = self(A, Q)
S = self(A.conj().T, -dX) # Eq 31, adjusted
A_bar = S.dot(X.conj().T) + S.conj().T.dot(X) def _trsyl(A: TensorLike, B: TensorLike, C: TensorLike) -> TensorVariable:
Q_bar = -S # Eq 29, adjusted A = as_tensor_variable(A)
B = as_tensor_variable(B)
C = as_tensor_variable(C)
return [A_bar, Q_bar] return cast(TensorVariable, Blockwise(TRSYL())(A, B, C))
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov()) class SolveSylvester(OpFromGraph):
"""
Wrapper Op for solving the continuous Sylvester equation :math:`AX + XB = C` for :math:`X`.
"""
gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
def _lop_solve_continuous_sylvester(inputs, outputs, output_grads):
""" """
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`. Closed-form gradients for the solution for the Sylvester equation.
Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
Note that these authors write the equation as AP + PB + C = 0. The code here follows scipy notation,
so P = X and C = -Q. This change of notation requires minor adjustment to equations 10 and 11c
"""
A, B, _ = inputs
(dX,) = output_grads
(X,) = outputs
S = solve_sylvester(A.conj().mT, B.conj().mT, -dX) # Eq 10
A_bar = S @ X.conj().mT # Eq 11a
B_bar = X.conj().mT @ S # Eq 11b
Q_bar = -S # Eq 11c
return [A_bar, B_bar, Q_bar]
def solve_sylvester(A: TensorLike, B: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the Sylvester equation :math:`AX + XB = C` for :math:`X`.
Following scipy notation, this function solves the continuous-time Sylvester equation.
Parameters Parameters
---------- ----------
A: TensorLike A: TensorLike
Square matrix of shape ``M x M``.
B: TensorLike
Square matrix of shape ``N x N``. Square matrix of shape ``N x N``.
Q: TensorLike Q: TensorLike
Square matrix of shape ``N x N``. Matrix of shape ``M x N``.
Returns Returns
------- -------
X: TensorVariable X: TensorVariable
Square matrix of shape ``N x N`` Matrix of shape ``M x N``.
"""
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
docstring for scipy.linalg.solve_discrete_lyapunov
""" """
gufunc_signature = "(m,m),(m,m)->(m,m)"
def make_node(self, A, B):
A = as_tensor_variable(A) A = as_tensor_variable(A)
B = as_tensor_variable(B) B = as_tensor_variable(B)
Q = as_tensor_variable(Q)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype) A_matrix = pt.matrix(dtype=A.dtype, shape=A.type.shape[-2:])
X = pytensor.tensor.matrix(dtype=out_dtype) B_matrix = pt.matrix(dtype=B.dtype, shape=B.type.shape[-2:])
Q_matrix = pt.matrix(dtype=Q.dtype, shape=Q.type.shape[-2:])
return pytensor.graph.basic.Apply(self, [A, B], [X]) R, U = schur(A_matrix, output="real")
S, V = schur(B_matrix, output="real")
F = U.conj().mT @ Q_matrix @ V
def perform(self, node, inputs, output_storage): Y = _trsyl(R, S, F)
(A, B) = inputs X = U @ Y @ V.conj().mT
X = output_storage[0]
out_dtype = node.outputs[0].type.dtype op = SolveSylvester(
X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( inputs=[A_matrix, B_matrix, Q_matrix],
out_dtype outputs=[X],
lop_overrides=_lop_solve_continuous_sylvester,
) )
def infer_shape(self, fgraph, node, shapes): return cast(TensorVariable, Blockwise(op)(A, B, Q))
return [shapes[0]]
def grad(self, inputs, output_grads):
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
A, Q = inputs
(dX,) = output_grads
X = self(A, Q)
# Eq 41, note that it is not written as a proper Lyapunov equation
S = self(A.conj().T, dX)
A_bar = pytensor.tensor.linalg.matrix_dot(
S, A, X.conj().T
) + pytensor.tensor.linalg.matrix_dot(S.conj().T, A, X)
Q_bar = S
return [A_bar, Q_bar]
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`.
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov()) Note that the lyapunov equation is a special case of the Sylvester equation, with :math:`B = A^H`. This function
thus simply calls `solve_sylvester` with the appropriate arguments.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
def _direct_solve_discrete_lyapunov( Returns
A: TensorVariable, Q: TensorVariable -------
) -> TensorVariable: X: TensorVariable
r""" Square matrix of shape ``N x N``
Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the kronecker method of Magnus and
Neudecker.
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
""" """
A = as_tensor_variable(A)
Q = as_tensor_variable(Q)
if A.type.dtype.startswith("complex"): return solve_sylvester(A, A.conj().mT, Q)
AxA = kron(A, A.conj())
else:
AxA = kron(A, A)
eye = pt.eye(AxA.shape[-1])
vec_Q = Q.ravel() class SolveBilinearDiscreteLyapunov(OpFromGraph):
vec_X = solve(eye - AxA, vec_Q, b_ndim=1) """
Wrapper Op for solving the discrete Lyapunov equation :math:`A X A^H - X = Q` for :math:`X`.
return reshape(vec_X, A.shape) Required so that backends that do not support method='bilinear' in `solve_discrete_lyapunov` can be rewritten
to method='direct'.
"""
def solve_discrete_lyapunov( def solve_discrete_lyapunov(
...@@ -1477,12 +1516,29 @@ def solve_discrete_lyapunov( ...@@ -1477,12 +1516,29 @@ def solve_discrete_lyapunov(
Q = as_tensor_variable(Q) Q = as_tensor_variable(Q)
if method == "direct": if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature vec_kron = pt.vectorize(kron, signature="(n,n),(n,n)->(m,m)")
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q) AxA = vec_kron(A, A.conj())
return cast(TensorVariable, X) eye = pt.eye(AxA.shape[-1])
vec_Q = join_dims(Q, start_axis=-2, n_axes=2)
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
return reshape(vec_X, A.shape)
elif method == "bilinear": elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q)) I = pt.eye(A.shape[-2])
B_1 = A.conj().mT + I
B_2 = A.conj().mT - I
B = solve(B_1.mT, B_2.mT).mT
AI_inv_Q = solve(A + I, Q)
C = 2 * solve(B_1.mT, AI_inv_Q.mT).mT
X = solve_continuous_lyapunov(B.conj().mT, -C)
op = SolveBilinearDiscreteLyapunov(inputs=[A, Q], outputs=[X])
return cast(TensorVariable, op(A, Q))
else: else:
raise ValueError(f"Unknown method {method}") raise ValueError(f"Unknown method {method}")
...@@ -2270,5 +2326,6 @@ __all__ = [ ...@@ -2270,5 +2326,6 @@ __all__ = [
"solve_continuous_lyapunov", "solve_continuous_lyapunov",
"solve_discrete_are", "solve_discrete_are",
"solve_discrete_lyapunov", "solve_discrete_lyapunov",
"solve_sylvester",
"solve_triangular", "solve_triangular",
] ]
...@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import ( ...@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import (
solve_continuous_lyapunov, solve_continuous_lyapunov,
solve_discrete_are, solve_discrete_are,
solve_discrete_lyapunov, solve_discrete_lyapunov,
solve_sylvester,
solve_triangular, solve_triangular,
) )
from pytensor.tensor.type import dmatrix, matrix, tensor, vector from pytensor.tensor.type import dmatrix, matrix, tensor, vector
...@@ -916,6 +917,67 @@ def test_expm_grad(mode): ...@@ -916,6 +917,67 @@ def test_expm_grad(mode):
utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5) utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
# The batched complex case is not parametrized: it raised an error because
# BatchedDot is not implemented for complex numbers.
@pytest.mark.parametrize(
    "shape, use_complex",
    [((5, 5), False), ((5, 5), True), ((5, 5, 5), False)],
    ids=["float", "complex", "batch_float"],
)
def test_solve_continuous_sylvester(shape: tuple[int, ...], use_complex: bool):
    """
    Check that ``solve_sylvester`` returns ``X`` satisfying ``A @ X + X @ B == Q``
    and that it agrees with ``scipy.linalg.solve_sylvester``.
    """
    # Seeded like the other tests in this module so failures are reproducible.
    rng = np.random.default_rng(utt.fetch_seed())

    dtype = config.floatX
    if use_complex:
        dtype = "complex128" if dtype == "float64" else "complex64"

    A1, A2 = rng.normal(size=(2, *shape))
    B1, B2 = rng.normal(size=(2, *shape))
    Q1, Q2 = rng.normal(size=(2, *shape))

    if use_complex:
        A_val = A1 + 1j * A2
        B_val = B1 + 1j * B2
        Q_val = Q1 + 1j * Q2
    else:
        A_val = A1
        B_val = B1
        Q_val = Q1

    # The draws are float64/complex128; cast them to the dtype of the symbolic
    # inputs so the compiled function also accepts them when floatX is float32.
    A_val, B_val, Q_val = (
        A_val.astype(dtype),
        B_val.astype(dtype),
        Q_val.astype(dtype),
    )

    A = pt.tensor("A", shape=shape, dtype=dtype)
    B = pt.tensor("B", shape=shape, dtype=dtype)
    Q = pt.tensor("Q", shape=shape, dtype=dtype)

    X = solve_sylvester(A, B, Q)
    Q_recovered = A @ X + X @ B

    fn = function([A, B, Q], [X, Q_recovered])
    X_val, Q_recovered_val = fn(A_val, B_val, Q_val)

    vec_sylvester = np.vectorize(
        scipy_linalg.solve_sylvester, signature="(m,m),(m,m),(m,m)->(m,m)"
    )

    # Single precision cannot meet 1e-8; use the same tolerance rule as the
    # Lyapunov tests in this module.
    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
    np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)
    np.testing.assert_allclose(
        X_val, vec_sylvester(A_val, B_val, Q_val), atol=atol, rtol=rtol
    )
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_sylvester_grad(shape: tuple[int], use_complex):
    """
    Numerically verify the gradient of ``solve_sylvester`` with respect to
    all three inputs, for real-valued inputs in double precision only.
    """
    float_dtype = config.floatX
    if float_dtype == "float32":
        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
    if use_complex:
        pytest.skip(reason="Complex numbers are not supported in the gradient test")

    rng = np.random.default_rng(utt.fetch_seed())

    # Three sequential draws, used as the A, B and Q evaluation points (in that order).
    test_points = [rng.normal(size=shape).astype(float_dtype) for _ in range(3)]

    utt.verify_grad(solve_sylvester, pt=test_points, rng=rng)
def recover_Q(A, X, continuous=True): def recover_Q(A, X, continuous=True):
if continuous: if continuous:
return A @ X + X @ A.conj().T return A @ X + X @ A.conj().T
...@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient( ...@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient(
) )
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) def test_solve_continuous_lyapunov():
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) # solve_continuous_lyapunov just calls solve_sylvester, so extensive tests are not needed.
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool): A = pt.tensor("A", shape=(3, 5, 5))
dtype = config.floatX Q = pt.tensor("Q", shape=(3, 5, 5))
if use_complex and dtype == "float32":
pytest.skip(
"Not enough precision in complex64 to do schur decomposition "
"(ill-conditioned matrix errors arise)"
)
rng = np.random.default_rng(utt.fetch_seed())
if use_complex:
precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
if use_complex:
A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
a = pt.tensor(name="a", shape=shape, dtype=dtype)
q = pt.tensor(name="q", shape=shape, dtype=dtype)
x = solve_continuous_lyapunov(a, q)
f = function([a, q], x)
X = f(A, Q)
Q_recovered = vec_recover_Q(A, X, continuous=True)
atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
X = solve_continuous_lyapunov(A, Q)
Q_recovered = A @ X + X @ A.conj().mT
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) fn = function([A, Q], [X, Q_recovered])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = rng.normal(size=shape).astype(config.floatX) A_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX) Q_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
_, Q_recovered_val = fn(A_val, Q_val)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng) atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)
utt.verify_grad(solve_continuous_lyapunov, pt=[A_val, Q_val], rng=rng)
@pytest.mark.parametrize("add_batch_dim", [False, True]) @pytest.mark.parametrize("add_batch_dim", [False, True])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论