Commit a3bf6bb6 authored by jessegrabowski, committed by Jesse Grabowski

Add TRSYL Op, refactor linear control Ops

Parent afabe83f
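For context: TRSYL is the LAPACK routine that solves the (quasi-)triangular Sylvester equation, and a general solve of A X + X B = Q is built on top of it via the Bartels-Stewart algorithm: Schur-reduce A and B^H, solve the triangular system with ?trsyl, then transform back. This is also how scipy.linalg.solve_sylvester is implemented. A minimal sketch of the pattern (the function name is illustrative, not the new Op's internals):

    import numpy as np
    from scipy.linalg import get_lapack_funcs, schur

    def sylvester_via_trsyl(a, b, q):
        # Schur-reduce: a = u @ r @ u^H and b^H = v @ s @ v^H.
        r, u = schur(a, output="real")
        s, v = schur(b.conj().T, output="real")
        # Move the right-hand side into the Schur bases.
        f = u.conj().T @ q @ v
        # Solve the quasi-triangular equation R Y + Y S^H = scale * F.
        (trsyl,) = get_lapack_funcs(("trsyl",), (r, s, f))
        y, scale, info = trsyl(r, s, f, tranb="C")
        if info < 0:
            raise ValueError(f"illegal value in trsyl argument {-info}")
        # scale < 1 only when LAPACK rescales to avoid overflow.
        return u @ (y / scale) @ v.conj().T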
@@ -55,8 +55,8 @@ from pytensor.tensor.slinalg import (
     LUFactor,
     Solve,
     SolveBase,
+    SolveBilinearDiscreteLyapunov,
     SolveTriangular,
-    _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
     solve,
@@ -1045,10 +1045,10 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
     return [eye_input * (non_eye_input**0.5)]


-@node_rewriter([_bilinear_solve_discrete_lyapunov])
+@node_rewriter([SolveBilinearDiscreteLyapunov])
 def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
     """
-    Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
+    Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
     """
     A, B = (cast(TensorVariable, x) for x in node.inputs)
     result = solve_discrete_lyapunov(A, B, method="direct")
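The "direct" method that the rewrite substitutes solves the discrete Lyapunov equation A X A^H - X + Q = 0 by vectorizing it into a single linear system, so every step is expressible in JAX. A minimal NumPy sketch of the computation (the function name is illustrative):

    import numpy as np

    def discrete_lyapunov_direct(A, Q):
        # A X A^H - X + Q = 0  <=>  (I - kron(A, conj(A))) vec(X) = vec(Q),
        # using the row-major identity vec(A X B) = kron(A, B.T) @ vec(X)
        # with B = A^H, so B.T = conj(A).
        n = A.shape[-1]
        lhs = np.eye(n * n) - np.kron(A, A.conj())
        x = np.linalg.solve(lhs, Q.reshape(-1))
        return x.reshape(n, n)

The Kronecker system is n^2 x n^2, so this trades the Schur-based bilinear solver's O(n^3) cost for roughly O(n^6), but every operation is JAX-traceable, which is what the rewrite needs.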
......
Diff collapsed.
@@ -36,6 +36,7 @@ from pytensor.tensor.slinalg import (
     solve_continuous_lyapunov,
     solve_discrete_are,
     solve_discrete_lyapunov,
+    solve_sylvester,
     solve_triangular,
 )
 from pytensor.tensor.type import dmatrix, matrix, tensor, vector
@@ -916,6 +917,67 @@ def test_expm_grad(mode):
     utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)


+@pytest.mark.parametrize(
+    "shape, use_complex",
+    [((5, 5), False), ((5, 5), True), ((5, 5, 5), False)],
+    ids=["float", "complex", "batch_float"],
+)
+def test_solve_continuous_sylvester(shape: tuple[int], use_complex: bool):
+    # The batched complex case is excluded: BatchedDot is not implemented for
+    # complex numbers.
+    rng = np.random.default_rng()
+    dtype = config.floatX
+    if use_complex:
+        dtype = "complex128" if dtype == "float64" else "complex64"
+
+    A1, A2 = rng.normal(size=(2, *shape))
+    B1, B2 = rng.normal(size=(2, *shape))
+    Q1, Q2 = rng.normal(size=(2, *shape))
+
+    if use_complex:
+        A_val = A1 + 1j * A2
+        B_val = B1 + 1j * B2
+        Q_val = Q1 + 1j * Q2
+    else:
+        A_val = A1
+        B_val = B1
+        Q_val = Q1
+
+    A = pt.tensor("A", shape=shape, dtype=dtype)
+    B = pt.tensor("B", shape=shape, dtype=dtype)
+    Q = pt.tensor("Q", shape=shape, dtype=dtype)
+
+    X = solve_sylvester(A, B, Q)
+    Q_recovered = A @ X + X @ B
+
+    fn = function([A, B, Q], [X, Q_recovered])
+    X_val, Q_recovered_val = fn(A_val, B_val, Q_val)
+
+    vec_sylvester = np.vectorize(
+        scipy_linalg.solve_sylvester, signature="(m,m),(m,m),(m,m)->(m,m)"
+    )
+
+    np.testing.assert_allclose(Q_recovered_val, Q_val, atol=1e-8, rtol=1e-8)
+    np.testing.assert_allclose(
+        X_val, vec_sylvester(A_val, B_val, Q_val), atol=1e-8, rtol=1e-8
+    )
+
+
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_sylvester_grad(shape: tuple[int], use_complex):
+    if config.floatX == "float32":
+        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+    if use_complex:
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = rng.normal(size=shape).astype(config.floatX)
+    B = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)
+
+    utt.verify_grad(solve_sylvester, pt=[A, B, Q], rng=rng)
+
+
 def recover_Q(A, X, continuous=True):
     if continuous:
         return A @ X + X @ A.conj().T
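A side note on the reference computation in the new test above: np.vectorize with a gufunc signature turns SciPy's matrix-only solve_sylvester into a function that loops over leading batch dimensions, matching the batched Op. A standalone illustration (shapes are arbitrary):

    import numpy as np
    from scipy import linalg as scipy_linalg

    vec_sylvester = np.vectorize(
        scipy_linalg.solve_sylvester, signature="(m,m),(m,m),(m,m)->(m,m)"
    )

    rng = np.random.default_rng(0)
    A, B, Q = rng.normal(size=(3, 4, 5, 5))  # three stacks of four 5x5 matrices
    X = vec_sylvester(A, B, Q)               # shape (4, 5, 5): one solve per batch entry
    np.testing.assert_allclose(A @ X + X @ B, Q, atol=1e-8)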
@@ -985,60 +1047,24 @@ def test_solve_discrete_lyapunov_gradient(
 )


-@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
-def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
-    dtype = config.floatX
-    if use_complex and dtype == "float32":
-        pytest.skip(
-            "Not enough precision in complex64 to do schur decomposition "
-            "(ill-conditioned matrix errors arise)"
-        )
-    rng = np.random.default_rng(utt.fetch_seed())
-
-    if use_complex:
-        precision = int(dtype[-2:])  # 64 or 32
-        dtype = f"complex{int(2 * precision)}"
-
-    A1, A2 = rng.normal(size=(2, *shape))
-    Q1, Q2 = rng.normal(size=(2, *shape))
-
-    if use_complex:
-        A = A1 + 1j * A2
-        Q = Q1 + 1j * Q2
-    else:
-        A = A1
-        Q = Q1
-
-    A, Q = A.astype(dtype), Q.astype(dtype)
-
-    a = pt.tensor(name="a", shape=shape, dtype=dtype)
-    q = pt.tensor(name="q", shape=shape, dtype=dtype)
-    x = solve_continuous_lyapunov(a, q)
-
-    f = function([a, q], x)
-
-    X = f(A, Q)
-    Q_recovered = vec_recover_Q(A, X, continuous=True)
-
-    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
-    np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
-
-
-@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
-def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
-    if config.floatX == "float32":
-        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
-    if use_complex:
-        pytest.skip(reason="Complex numbers are not supported in the gradient test")
-
-    rng = np.random.default_rng(utt.fetch_seed())
-    A = rng.normal(size=shape).astype(config.floatX)
-    Q = rng.normal(size=shape).astype(config.floatX)
-
-    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
+def test_solve_continuous_lyapunov():
+    # solve_continuous_lyapunov just calls solve_sylvester, so extensive tests are not needed.
+    A = pt.tensor("A", shape=(3, 5, 5))
+    Q = pt.tensor("Q", shape=(3, 5, 5))
+    X = solve_continuous_lyapunov(A, Q)
+    Q_recovered = A @ X + X @ A.conj().mT
+
+    fn = function([A, Q], [X, Q_recovered])
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    A_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
+    Q_val = rng.normal(size=(3, 5, 5)).astype(config.floatX)
+
+    _, Q_recovered_val = fn(A_val, Q_val)
+
+    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
+    np.testing.assert_allclose(Q_recovered_val, Q_val, atol=atol, rtol=rtol)
+
+    utt.verify_grad(solve_continuous_lyapunov, pt=[A_val, Q_val], rng=rng)


 @pytest.mark.parametrize("add_batch_dim", [False, True])
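For reference, the continuous Lyapunov equation A X + X A^H = Q is the Sylvester equation with B = A^H, which is exactly why solve_continuous_lyapunov can be reduced to a thin wrapper over solve_sylvester in this refactor. A quick SciPy cross-check of the identity (illustrative):

    import numpy as np
    from scipy import linalg as scipy_linalg

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 4))
    Q = rng.normal(size=(4, 4))

    X_lyap = scipy_linalg.solve_continuous_lyapunov(A, Q)    # solves A X + X A^H = Q
    X_sylv = scipy_linalg.solve_sylvester(A, A.conj().T, Q)  # same equation with B = A^H
    np.testing.assert_allclose(X_lyap, X_sylv, atol=1e-8)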
......