Unverified commit fffb84c1, authored by Jesse Grabowski, committed by GitHub

Cleanup for Optimal Control Ops (#1045)

* Blockwise optimal linear control ops
* Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov`
* Set `solve_discrete_lyapunov` method default to bilinear
* Appease mypy
* Restore method dispatching
* Use `pt.vectorize` on base `solve_discrete_lyapunov` case
* Apply JAX rewrite before canonicalization
* Improve tests
* Remove useless warning filters
* Fix local_blockwise_alloc rewrite

  The rewrite was squeezing too many dimensions of the alloced value when this didn't have dummy expand dims to the left.

* Fix float32 tests
* Test against complex inputs
* Appease ViPy (Vieira-py type checking)
* Remove condition from `TensorLike` import
* Infer dtype from `node.outputs.type.dtype`
* Remove unused mypy ignore
* Don't manually set dtype of output; revert change to `_solve_discrete_lyapunov`
* Set dtype of Op outputs

---------

Co-authored-by: ricardoV94 <ricardo.vieira1994@gmail.com>
Parent dae731d1
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
         value, *shape = inp.owner.inputs

         # Check what to do with the value of the Alloc
-        squeezed_value = _squeeze_left(value, batch_ndim)
-        missing_ndim = len(shape) - value.type.ndim
+        missing_ndim = inp.type.ndim - value.type.ndim
+        squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
         if (
             (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
             != inp.type.broadcastable[batch_ndim:]
...
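Editor's note: the off-by-`missing_ndim` bug described in the commit message is easiest to see with concrete numbers. A hypothetical sketch (values invented, not part of the diff): `Alloc` broadcasts `value` up to `inp`'s rank, so only the batch dimensions that `value` actually carries as dummy leading axes may be squeezed away, not all `batch_ndim` of them.

```python
# Hypothetical illustration of the corrected bookkeeping (values invented)
inp_ndim, value_ndim, batch_ndim = 4, 3, 2

missing_ndim = inp_ndim - value_ndim    # 1 dim that Alloc adds on the left
squeezable = batch_ndim - missing_ndim  # only 1 leading dim may be squeezed, not 2
assert (missing_ndim, squeezable) == (1, 1)
```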
@@ -4,9 +4,11 @@ from typing import cast
 from pytensor import Variable
 from pytensor import tensor as pt
+from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
+    in2out,
     node_rewriter,
 )
 from pytensor.scalar.basic import Mul
@@ -45,9 +47,11 @@ from pytensor.tensor.slinalg import (
     Cholesky,
     Solve,
     SolveBase,
+    _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
     solve,
+    solve_discrete_lyapunov,
     solve_triangular,
 )
@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
         non_eye_input = pt.shape_padaxis(non_eye_input, -2)

     return [eye_input * (non_eye_input**0.5)]
+
+
+@node_rewriter([_bilinear_solve_discrete_lyapunov])
+def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
+    """
+    Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
+    """
+    A, B = (cast(TensorVariable, x) for x in node.inputs)
+    result = solve_discrete_lyapunov(A, B, method="direct")
+    return [result]
+
+
+optdb.register(
+    "jax_bilinaer_lyapunov_to_direct",
+    in2out(jax_bilinaer_lyapunov_to_direct),
+    "jax",
+    position=0.9,  # Run before canonicalization
+)
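Editor's note: a minimal sketch (not part of this diff, assumes a working JAX install) of how the new rewrite fires. Compiling with `mode="JAX"` applies the rewrites tagged `"jax"`, so the `BilinearSolveDiscreteLyapunov` Op is swapped for the direct method before JAX ever sees the graph.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_lyapunov

A = pt.matrix("A")
Q = pt.matrix("Q")
# Builds a BilinearSolveDiscreteLyapunov node...
X = solve_discrete_lyapunov(A, Q, method="bilinear")

# ...which the "jax"-tagged rewrite replaces with the direct method at compile time
f = pytensor.function([A, Q], X, mode="JAX")
print(f(np.eye(3) * 0.5, np.eye(3)))
```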
...
@@ -2,7 +2,7 @@ import logging
 import typing
 import warnings
 from functools import reduce
-from typing import TYPE_CHECKING, Literal, cast
+from typing import Literal, cast

 import numpy as np
 import scipy.linalg
@@ -11,7 +11,7 @@ import pytensor
 import pytensor.tensor as pt
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
-from pytensor.tensor import as_tensor_variable
+from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
 from pytensor.tensor.blockwise import Blockwise
@@ -21,9 +21,6 @@ from pytensor.tensor.type import matrix, tensor, vector
 from pytensor.tensor.variable import TensorVariable

-if TYPE_CHECKING:
-    from pytensor.tensor import TensorLike
-
 logger = logging.getLogger(__name__)
@@ -777,7 +774,16 @@ expm = Expm()
 class SolveContinuousLyapunov(Op):
+    """
+    Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`.
+
+    Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
+    efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
+    scipy.linalg.solve_continuous_lyapunov.
+    """
+
     __props__ = ()
+    gufunc_signature = "(m,m),(m,m)->(m,m)"

     def make_node(self, A, B):
         A = as_tensor_variable(A)
@@ -792,7 +798,8 @@ class SolveContinuousLyapunov(Op):
         (A, B) = inputs
         X = output_storage[0]

-        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+        out_dtype = node.outputs[0].type.dtype
+        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
@@ -813,7 +820,41 @@ class SolveContinuousLyapunov(Op):
         return [A_bar, Q_bar]


+_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
+
+
+def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
+    """
+    Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
+
+    Parameters
+    ----------
+    A: TensorLike
+        Square matrix of shape ``N x N``.
+    Q: TensorLike
+        Square matrix of shape ``N x N``.
+
+    Returns
+    -------
+    X: TensorVariable
+        Square matrix of shape ``N x N``
+    """
+
+    return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
 class BilinearSolveDiscreteLyapunov(Op):
+    """
+    Solves a discrete Lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X`.
+
+    The solution is computed by first transforming the discrete-time problem into a continuous-time form. The
+    continuous-time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved. For
+    more details, see the docstring for scipy.linalg.solve_discrete_lyapunov.
+    """
+
+    gufunc_signature = "(m,m),(m,m)->(m,m)"
+
     def make_node(self, A, B):
         A = as_tensor_variable(A)
         B = as_tensor_variable(B)
@@ -827,7 +868,10 @@ class BilinearSolveDiscreteLyapunov(Op):
         (A, B) = inputs
         X = output_storage[0]

-        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+        out_dtype = node.outputs[0].type.dtype
+        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
+            out_dtype
+        )

     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
@@ -849,46 +893,56 @@ class BilinearSolveDiscreteLyapunov(Op):
         return [A_bar, Q_bar]


-_solve_continuous_lyapunov = SolveContinuousLyapunov()
-_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
+_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())


-def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-    A_ = as_tensor_variable(A)
-    Q_ = as_tensor_variable(Q)
+def _direct_solve_discrete_lyapunov(
+    A: TensorVariable, Q: TensorVariable
+) -> TensorVariable:
+    r"""
+    Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the Kronecker method of Magnus and
+    Neudecker.
+
+    This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape
+    :math:`N^2 \times N^2`. As a result, this method scales poorly with the size of :math:`N`, and should be avoided
+    for large :math:`N`.
+    """
-    if "complex" in A_.type.dtype:
-        AA = kron(A_, A_.conj())
+    if A.type.dtype.startswith("complex"):
+        AxA = kron(A, A.conj())
     else:
-        AA = kron(A_, A_)
+        AxA = kron(A, A)
+
+    eye = pt.eye(AxA.shape[-1])

-    X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
-    return cast(TensorVariable, reshape(X, Q_.shape))
+    vec_Q = Q.ravel()
+    vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
+
+    return cast(TensorVariable, reshape(vec_X, A.shape))
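Editor's note: for reference, the identity behind the direct method, written as a sketch with the row-major vectorization that matches `ravel`:

```latex
% vec trick used by _direct_solve_discrete_lyapunov. With vec_r the row-major
% vectorization (numpy's ravel), vec_r(AXB) = (A \otimes B^T) vec_r(X).
% Taking B = A^H turns the discrete Lyapunov equation into the N^2 x N^2
% linear system that solve(eye - AxA, vec_Q, b_ndim=1) builds (scipy's sign
% convention for Q):
\operatorname{vec}_r\!\left(A X A^{H}\right) = \left(A \otimes \bar{A}\right)\operatorname{vec}_r(X)
\quad\Longrightarrow\quad
\left(I_{N^2} - A \otimes \bar{A}\right)\operatorname{vec}_r(X) = \operatorname{vec}_r(Q)
```

Solving this dense :math:`N^2 \times N^2` system is exactly the poor scaling in :math:`N` that the docstring warns about.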
 def solve_discrete_lyapunov(
-    A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
+    A: TensorLike,
+    Q: TensorLike,
+    method: Literal["direct", "bilinear"] = "bilinear",
 ) -> TensorVariable:
     """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.

     Parameters
     ----------
-    A
-        Square matrix of shape N x N; must have the same shape as Q
-    Q
-        Square matrix of shape N x N; must have the same shape as A
-    method
-        Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
-        solves the problem directly via matrix inversion. This has a pure
-        PyTensor implementation and can thus be cross-compiled to supported
-        backends, and should be preferred when ``N`` is not large. The direct
-        method scales poorly with the size of ``N``, and the bilinear can be
-        used in these cases.
+    A: TensorLike
+        Square matrix of shape N x N
+    Q: TensorLike
+        Square matrix of shape N x N
+    method: str, one of ``"direct"`` or ``"bilinear"``
+        Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
+        PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
+        ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear method can be
+        used in these cases.

     Returns
     -------
-    Square matrix of shape ``N x N``, representing the solution to the
-    Lyapunov equation
+    X: TensorVariable
+        Square matrix of shape ``N x N``. Solution to the Lyapunov equation

     """

     if method not in ["direct", "bilinear"]:
@@ -896,36 +950,26 @@ def solve_discrete_lyapunov(
             f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
         )

-    if method == "direct":
-        return _direct_solve_discrete_lyapunov(A, Q)
-    if method == "bilinear":
-        return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
-
-
-def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-    """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
-
-    Parameters
-    ----------
-    A
-        Square matrix of shape ``N x N``; must have the same shape as `Q`.
-    Q
-        Square matrix of shape ``N x N``; must have the same shape as `A`.
-
-    Returns
-    -------
-    Square matrix of shape ``N x N``, representing the solution to the
-    Lyapunov equation
-    """
-
-    return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
+    A = as_tensor_variable(A)
+    Q = as_tensor_variable(Q)
+
+    if method == "direct":
+        signature = BilinearSolveDiscreteLyapunov.gufunc_signature
+        X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
+        return cast(TensorVariable, X)
+    elif method == "bilinear":
+        return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))
+    else:
+        raise ValueError(f"Unknown method {method}")
-class SolveDiscreteARE(pt.Op):
+class SolveDiscreteARE(Op):
     __props__ = ("enforce_Q_symmetric",)
+    gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

-    def __init__(self, enforce_Q_symmetric=False):
+    def __init__(self, enforce_Q_symmetric: bool = False):
         self.enforce_Q_symmetric = enforce_Q_symmetric

     def make_node(self, A, B, Q, R):
@@ -946,9 +990,8 @@ class SolveDiscreteARE(pt.Op):
         if self.enforce_Q_symmetric:
             Q = 0.5 * (Q + Q.T)

-        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
-            node.outputs[0].type.dtype
-        )
+        out_dtype = node.outputs[0].type.dtype
+        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
@@ -960,14 +1003,16 @@ class SolveDiscreteARE(pt.Op):
         (dX,) = output_grads
         X = self(A, B, Q, R)

-        K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
-        K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
-        K = matrix_dot(K_inner_inv, B.T, X, A)
+        K_inner = R + matrix_dot(B.T, X, B)
+
+        # K_inner is guaranteed to be symmetric, because X and R are symmetric
+        K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
+        K = matrix_dot(K_inner_inv_BT, X, A)

         A_tilde = A - B.dot(K)

         dX_symm = 0.5 * (dX + dX.T)
-        S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
+        S = solve_discrete_lyapunov(A_tilde, dX_symm)

         A_bar = 2 * matrix_dot(X, A_tilde, S)
         B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
@@ -977,30 +1022,45 @@
         return [A_bar, B_bar, Q_bar, R_bar]
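Editor's note: in matrix form, the gain computed in the gradient above is the sketch below. Since `R` and `X` are symmetric, so is `K_inner`, which is why the new code does a symmetric solve against `B.T` instead of materializing an explicit inverse as the old code did.

```latex
% Feedback gain and closed-loop matrix used in the DARE gradient:
K = \left(R + B^{\mathsf{T}} X B\right)^{-1} B^{\mathsf{T}} X A,
\qquad
\tilde{A} = A - B K
```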
-def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
+def solve_discrete_are(
+    A: TensorLike,
+    B: TensorLike,
+    Q: TensorLike,
+    R: TensorLike,
+    enforce_Q_symmetric: bool = False,
+) -> TensorVariable:
     """
     Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

+    Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
+    solution to Linear-Quadratic Regulator (LQR) and Linear-Quadratic-Gaussian (LQG) control problems, and as the
+    steady-state covariance of the Kalman Filter.
+
+    Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
+    solution. This stable solution, if it exists, will be returned by this function.
+
     Parameters
     ----------
-    A: ArrayLike
+    A: TensorLike
         Square matrix of shape M x M
-    B: ArrayLike
-        Square matrix of shape M x M
+    B: TensorLike
+        Matrix of shape M x N
-    Q: ArrayLike
+    Q: TensorLike
         Symmetric square matrix of shape M x M
-    R: ArrayLike
+    R: TensorLike
         Square matrix of shape N x N
     enforce_Q_symmetric: bool
         If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

     Returns
     -------
-    X: pt.matrix
+    X: TensorVariable
         Square matrix of shape M x M, representing the solution to the DARE
     """

-    return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
+    return cast(
+        TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
+    )


 def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
...
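Editor's note: since the public wrapper now applies `Blockwise`, batched Riccati solves work out of the box. A small sketch, reusing the matrices from the tests below (the leading stack of 5 is arbitrary):

```python
import numpy as np
from pytensor.tensor.slinalg import solve_discrete_are

a = np.array([[4.0, 3.0], [-4.5, -3.5]])
b = np.array([[1.0], [-1.0]])
q = np.array([[9.0, 6.0], [6.0, 4.0]])
r = np.array([[1.0]])

# Stack a leading batch dimension; Blockwise maps the (m,m),(m,n),(m,m),(n,n)->(m,m) core
a, b, q, r = (np.stack([x] * 5) for x in (a, b, q, r))
X = solve_discrete_are(a, b, q, r).eval()
print(X.shape)  # (5, 2, 2)
```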
+from functools import partial
+from typing import Literal
+
 import numpy as np
 import pytest
@@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower):
         None,
     ],
 )
+
+
+@pytest.mark.parametrize("method", ["direct", "bilinear"])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+def test_jax_solve_discrete_lyapunov(
+    method: Literal["direct", "bilinear"], shape: tuple[int]
+):
+    A = pt.tensor(name="A", shape=shape)
+    B = pt.tensor(name="B", shape=shape)
+    out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
+    out_fg = FunctionGraph([A, B], [out])
+
+    atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.random.normal(size=shape).astype(config.floatX),
+            np.random.normal(size=shape).astype(config.floatX),
+        ],
+        jax_mode="JAX",
+        assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
+    )
...
 import functools
 import itertools
+from typing import Literal

 import numpy as np
 import pytest
@@ -514,75 +515,133 @@ def test_expm_grad_3():
     utt.verify_grad(expm, [A], rng=rng)


-def test_solve_discrete_lyapunov_via_direct_real():
-    N = 5
+def recover_Q(A, X, continuous=True):
+    if continuous:
+        return A @ X + X @ A.conj().T
+    else:
+        return X - A @ X @ A.conj().T
+
+
+vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
+
+
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+@pytest.mark.parametrize("method", ["direct", "bilinear"])
+def test_solve_discrete_lyapunov(
+    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
+):
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix("a")
-    q = pt.dmatrix("q")
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
+    dtype = config.floatX
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
+
+    A1, A2 = rng.normal(size=(2, *shape))
+    Q1, Q2 = rng.normal(size=(2, *shape))
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
+    A, Q = A.astype(dtype), Q.astype(dtype)
+    a = pt.tensor(name="a", shape=shape, dtype=dtype)
+    q = pt.tensor(name="q", shape=shape, dtype=dtype)

-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
+    x = solve_discrete_lyapunov(a, q, method=method)
+    f = function([a, q], x)

     X = f(A, Q)
-    assert np.allclose(A @ X @ A.T - X + Q, 0.0)
+    Q_recovered = vec_recover_Q(A, X, continuous=False)

-    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+    atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8
+    np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol)


-@pytest.mark.filterwarnings("ignore::UserWarning")
-def test_solve_discrete_lyapunov_via_direct_complex():
-    # Conj doesn't have C-op; filter the warning.
-    N = 5
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+@pytest.mark.parametrize("method", ["direct", "bilinear"])
+def test_solve_discrete_lyapunov_gradient(
+    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
+):
+    if config.floatX == "float32":
+        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+    if use_complex:
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")
+
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.zmatrix()
-    q = pt.zmatrix()
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
-
-    A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
-    Q = rng.normal(size=(N, N))
-    X = f(A, Q)
-    np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
-
-    # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
-    # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)
+
+    utt.verify_grad(
+        functools.partial(solve_discrete_lyapunov, method=method),
+        pt=[A, Q],
+        rng=rng,
+    )
-def test_solve_discrete_lyapunov_via_bilinear():
-    N = 5
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
+    dtype = config.floatX
+    if use_complex and dtype == "float32":
+        pytest.skip(
+            "Not enough precision in complex64 to do schur decomposition "
+            "(ill-conditioned matrix errors arise)"
+        )
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix()
-    q = pt.dmatrix()
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")])
-
-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
-
-    X = f(A, Q)
-
-    np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
-    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
-
-
-def test_solve_continuous_lyapunov():
-    N = 5
-    rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix()
-    q = pt.dmatrix()
-    f = function([a, q], [solve_continuous_lyapunov(a, q)])
-
-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
+
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
+
+    A1, A2 = rng.normal(size=(2, *shape))
+    Q1, Q2 = rng.normal(size=(2, *shape))
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
+    A, Q = A.astype(dtype), Q.astype(dtype)
+
+    a = pt.tensor(name="a", shape=shape, dtype=dtype)
+    q = pt.tensor(name="q", shape=shape, dtype=dtype)
+    x = solve_continuous_lyapunov(a, q)
+
+    f = function([a, q], x)

     X = f(A, Q)
-    Q_recovered = A @ X + X @ A.conj().T
+    Q_recovered = vec_recover_Q(A, X, continuous=True)

+    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
+    np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
+    if config.floatX == "float32":
+        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+    if use_complex:
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)

-    np.testing.assert_allclose(Q_recovered.squeeze(), Q)
     utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
-def test_solve_discrete_are_forward():
+@pytest.mark.parametrize("add_batch_dim", [False, True])
+def test_solve_discrete_are_forward(add_batch_dim):
     # TEST CASE 4 : darex #1 -- taken from Scipy tests
     a, b, q, r = (
         np.array([[4, 3], [-4.5, -3.5]]),
@@ -590,29 +649,39 @@ def test_solve_discrete_are_forward():
         np.array([[9, 6], [6, 4]]),
         np.array([[1]]),
     )
-    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
+    if add_batch_dim:
+        a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])

-    x = solve_discrete_are(a, b, q, r).eval()
-    res = a.T.dot(x.dot(a)) - x + q
-    res -= (
-        a.conj()
-        .T.dot(x.dot(b))
-        .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
-    )
+    a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r])
+
+    x = solve_discrete_are(a, b, q, r)
+
+    def eval_fun(a, b, q, r, x):
+        term_1 = a.T @ x @ a
+        term_2 = a.T @ x @ b
+        term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a
+        return term_1 - x - term_2 @ term_3 + q
+
+    res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x)
+    res_np = res.eval()

     atol = 1e-4 if config.floatX == "float32" else 1e-12
-    np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)
+    np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol)


-def test_solve_discrete_are_grad():
+@pytest.mark.parametrize("add_batch_dim", [False, True])
+def test_solve_discrete_are_grad(add_batch_dim):
     a, b, q, r = (
         np.array([[4, 3], [-4.5, -3.5]]),
         np.array([[1], [-1]]),
         np.array([[9, 6], [6, 4]]),
         np.array([[1]]),
     )
-    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
+    if add_batch_dim:
+        a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
+    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])

     rng = np.random.default_rng(utt.fetch_seed())

     # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
...