Unverified 提交 288a3f34 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add `Op` corresponding to `scipy.linalg.solve_discrete_are` (#417)

* Add pytensor function corresponding to * Add pytensor function corresponding to * Cast numpy to node output dtype, rather than depending on config.floatX Change output type hint back to * Cast numpy to node output dtype, rather than depending on config.floatX Change output type hint back to * Use rather than for output equality tests
上级 34eaaa53
import logging import logging
import typing
import warnings import warnings
from typing import TYPE_CHECKING, Literal, Union from typing import TYPE_CHECKING, Literal, Union
...@@ -12,6 +13,7 @@ from pytensor.graph.op import Op ...@@ -12,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as atm from pytensor.tensor import math as atm
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.var import TensorVariable from pytensor.tensor.var import TensorVariable
...@@ -321,9 +323,6 @@ class SolveTriangular(SolveBase): ...@@ -321,9 +323,6 @@ class SolveTriangular(SolveBase):
return res return res
solvetriangular = SolveTriangular()
def solve_triangular( def solve_triangular(
a: TensorVariable, a: TensorVariable,
b: TensorVariable, b: TensorVariable,
...@@ -397,9 +396,6 @@ class Solve(SolveBase): ...@@ -397,9 +396,6 @@ class Solve(SolveBase):
) )
solve = Solve()
def solve(a, b, assume_a="gen", lower=False, check_finite=True): def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
...@@ -748,13 +744,9 @@ class BilinearSolveDiscreteLyapunov(Op): ...@@ -748,13 +744,9 @@ class BilinearSolveDiscreteLyapunov(Op):
_solve_continuous_lyapunov = SolveContinuousLyapunov() _solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov() _solve_bilinear_direct_lyapunov = typing.cast(
typing.Callable, BilinearSolveDiscreteLyapunov()
)
def iscomplexobj(x):
type_ = x.type
dtype = type_.dtype
return "complex" in dtype
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
...@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV ...@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
AA = kron(A_, A_) AA = kron(A_, A_)
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel()) X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
return reshape(X, Q_.shape) return typing.cast(TensorVariable, reshape(X, Q_.shape))
def solve_discrete_lyapunov( def solve_discrete_lyapunov(
...@@ -803,7 +795,7 @@ def solve_discrete_lyapunov( ...@@ -803,7 +795,7 @@ def solve_discrete_lyapunov(
if method == "direct": if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q) return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear": if method == "bilinear":
return _solve_bilinear_direct_lyapunov(A, Q) return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
...@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl ...@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
""" """
return _solve_continuous_lyapunov(A, Q) return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class SolveDiscreteARE(pt.Op):
__props__ = ("enforce_Q_symmetric",)
def __init__(self, enforce_Q_symmetric=False):
self.enforce_Q_symmetric = enforce_Q_symmetric
def make_node(self, A, B, Q, R):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
Q = as_tensor_variable(Q)
R = as_tensor_variable(R)
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
X = pytensor.tensor.matrix(dtype=out_dtype)
return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])
def perform(self, node, inputs, output_storage):
A, B, Q, R = inputs
X = output_storage[0]
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
node.outputs[0].type.dtype
)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
def grad(self, inputs, output_grads):
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
A, B, Q, R = inputs
(dX,) = output_grads
X = self(A, B, Q, R)
K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
K = matrix_dot(K_inner_inv, B.T, X, A)
A_tilde = A - B.dot(K)
dX_symm = 0.5 * (dX + dX.T)
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
A_bar = 2 * matrix_dot(X, A_tilde, S)
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
Q_bar = S
R_bar = matrix_dot(K, S, K.T)
return [A_bar, B_bar, Q_bar, R_bar]
def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
Parameters
----------
A: ArrayLike
Square matrix of shape M x M
B: ArrayLike
Square matrix of shape M x M
Q: ArrayLike
Symmetric square matrix of shape M x M
R: ArrayLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Returns
-------
X: pt.matrix
Square matrix of shape M x M, representing the solution to the DARE
"""
return typing.cast(
TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)
)
__all__ = [ __all__ = [
...@@ -832,4 +907,8 @@ __all__ = [ ...@@ -832,4 +907,8 @@ __all__ = [
"eigvalsh", "eigvalsh",
"kron", "kron",
"expm", "expm",
"solve_discrete_lyapunov",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
] ]
...@@ -22,6 +22,7 @@ from pytensor.tensor.slinalg import ( ...@@ -22,6 +22,7 @@ from pytensor.tensor.slinalg import (
kron, kron,
solve, solve,
solve_continuous_lyapunov, solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov, solve_discrete_lyapunov,
solve_triangular, solve_triangular,
) )
...@@ -532,7 +533,7 @@ class TestKron(utt.InferShapeTester): ...@@ -532,7 +533,7 @@ class TestKron(utt.InferShapeTester):
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten() scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
else: else:
scipy_val = scipy.linalg.kron(a, b) scipy_val = scipy.linalg.kron(a, b)
utt.assert_allclose(out, scipy_val) np.testing.assert_allclose(out, scipy_val)
def test_numpy_2d(self): def test_numpy_2d(self):
for shp0 in [(2, 3)]: for shp0 in [(2, 3)]:
...@@ -564,7 +565,10 @@ def test_solve_discrete_lyapunov_via_direct_real(): ...@@ -564,7 +565,10 @@ def test_solve_discrete_lyapunov_via_direct_real():
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_solve_discrete_lyapunov_via_direct_complex(): def test_solve_discrete_lyapunov_via_direct_complex():
# Conj doesn't have C-op; filter the warning.
N = 5 N = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
a = pt.zmatrix() a = pt.zmatrix()
...@@ -574,7 +578,7 @@ def test_solve_discrete_lyapunov_via_direct_complex(): ...@@ -574,7 +578,7 @@ def test_solve_discrete_lyapunov_via_direct_complex():
A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
Q = rng.normal(size=(N, N)) Q = rng.normal(size=(N, N))
X = f(A, Q) X = f(A, Q)
assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0) np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
# TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented. # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
# utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
...@@ -591,8 +595,8 @@ def test_solve_discrete_lyapunov_via_bilinear(): ...@@ -591,8 +595,8 @@ def test_solve_discrete_lyapunov_via_bilinear():
Q = rng.normal(size=(N, N)) Q = rng.normal(size=(N, N))
X = f(A, Q) X = f(A, Q)
assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
...@@ -607,6 +611,51 @@ def test_solve_continuous_lyapunov(): ...@@ -607,6 +611,51 @@ def test_solve_continuous_lyapunov():
Q = rng.normal(size=(N, N)) Q = rng.normal(size=(N, N))
X = f(A, Q) X = f(A, Q)
assert np.allclose(A @ X + X @ A.conj().T, Q) Q_recovered = A @ X + X @ A.conj().T
np.testing.assert_allclose(Q_recovered.squeeze(), Q)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng) utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
def test_solve_discrete_are_forward():
# TEST CASE 4 : darex #1 -- taken from Scipy tests
a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]),
np.array([[1], [-1]]),
np.array([[9, 6], [6, 4]]),
np.array([[1]]),
)
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
x = solve_discrete_are(a, b, q, r).eval()
res = a.T.dot(x.dot(a)) - x + q
res -= (
a.conj()
.T.dot(x.dot(b))
.dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
)
atol = 1e-4 if config.floatX == "float32" else 1e-12
np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)
def test_solve_discrete_are_grad():
a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]),
np.array([[1], [-1]]),
np.array([[9, 6], [6, 4]]),
np.array([[1]]),
)
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
rng = np.random.default_rng(utt.fetch_seed())
# TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
atol = 1e-4 if config.floatX == "float32" else 1e-12
utt.verify_grad(
functools.partial(solve_discrete_are, enforce_Q_symmetric=True),
pt=[a, b, q, r],
rng=rng,
abs_tol=atol,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论