Unverified 提交 fffb84c1 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Cleanup for Optimal Control Ops (#1045)

* Blockwise optimal linear control ops * Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov` * set `solve_discrete_lyapunov` method default to bilinear * Appease mypy * restore method dispatching * Use `pt.vectorize` on base `solve_discrete_lyapunov` case * Apply JAX rewrite before canonicalization * Improve tests * Remove useless warning filters * Fix local_blockwise_alloc rewrite The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left. * Fix float32 tests * Test against complex inputs * Appease ViPy (Vieira-py type checking) * Remove condition from `TensorLike` import * Infer dtype from `node.outputs.type.dtype` * Remove unused mypy ignore * Don't manually set dtype of output Revert change to `_solve_discrete_lyapunov` * Set dtype of Op outputs --------- Co-authored-by: 's avatarricardoV94 <ricardo.vieira1994@gmail.com>
上级 dae731d1
......@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim)
missing_ndim = len(shape) - value.type.ndim
missing_ndim = inp.type.ndim - value.type.ndim
squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
......
......@@ -4,9 +4,11 @@ from typing import cast
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
......@@ -45,9 +47,11 @@ from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
solve_discrete_lyapunov,
solve_triangular,
)
......@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
    """Rewrite ``BilinearSolveDiscreteLyapunov`` into a direct-method graph.

    The bilinear solver has no JAX implementation, so before compiling to JAX
    we substitute the equivalent ``method="direct"`` computation, which is
    built from Ops that JAX supports.
    """
    # NOTE(review): "bilinaer" in the function name is a typo, but it is also
    # the registered rewrite name, so it is kept for compatibility.
    A, B = (cast(TensorVariable, inp) for inp in node.inputs)
    return [solve_discrete_lyapunov(A, B, method="direct")]
# Register the rewrite under the "jax" tag so it only fires when compiling
# with the JAX backend.  Position 0.9 places it before canonicalization,
# ensuring the unsupported bilinear Op is eliminated before other rewrites
# see the graph.
optdb.register(
    "jax_bilinaer_lyapunov_to_direct",
    in2out(jax_bilinaer_lyapunov_to_direct),
    "jax",
    position=0.9,  # Run before canonicalization
)
差异被折叠。
from functools import partial
from typing import Literal
import numpy as np
import pytest
......@@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower):
None,
],
)
@pytest.mark.parametrize("method", ["direct", "bilinear"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
def test_jax_solve_discrete_lyapunov(
    method: Literal["direct", "bilinear"], shape: tuple[int]
):
    """JAX backend matches the Python backend for ``solve_discrete_lyapunov``."""
    A = pt.tensor(name="A", shape=shape)
    B = pt.tensor(name="B", shape=shape)

    out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
    out_fg = FunctionGraph([A, B], [out])

    # float32 needs a much looser tolerance than float64
    atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
    test_values = [np.random.normal(size=shape).astype(config.floatX) for _ in range(2)]
    compare_jax_and_py(
        out_fg,
        test_values,
        jax_mode="JAX",
        assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
    )
import functools
import itertools
from typing import Literal
import numpy as np
import pytest
......@@ -514,75 +515,133 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng)
def test_solve_discrete_lyapunov_via_direct_real():
N = 5
def recover_Q(A, X, continuous=True):
    """Reconstruct Q from a candidate Lyapunov solution X.

    Continuous equation: Q = A X + X A^H.
    Discrete equation:   Q = X - A X A^H.
    """
    A_H = A.conj().T
    if continuous:
        return A @ X + X @ A_H
    return X - A @ X @ A_H


# Batched version: maps over any leading dimensions of (A, X).
vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
@pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov(
    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
    """Forward check: X solves A X A^H - X + Q = 0, i.e. recovering Q from X
    reproduces the input Q (real and complex, with and without batch dims).
    """
    # The block previously contained stale lines from the pre-refactor version
    # of this test interleaved with the new code; only the new version is kept.
    rng = np.random.default_rng(utt.fetch_seed())
    dtype = config.floatX
    if use_complex:
        precision = int(dtype[-2:])  # 64 or 32
        dtype = f"complex{int(2 * precision)}"

    A1, A2 = rng.normal(size=(2, *shape))
    Q1, Q2 = rng.normal(size=(2, *shape))

    if use_complex:
        A = A1 + 1j * A2
        Q = Q1 + 1j * Q2
    else:
        A = A1
        Q = Q1

    A, Q = A.astype(dtype), Q.astype(dtype)

    a = pt.tensor(name="a", shape=shape, dtype=dtype)
    q = pt.tensor(name="q", shape=shape, dtype=dtype)
    x = solve_discrete_lyapunov(a, q, method=method)
    f = function([a, q], x)

    X = f(A, Q)
    Q_recovered = vec_recover_Q(A, X, continuous=False)

    atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8
    np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol)
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
@pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov_gradient(
    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
    """Gradient check for ``solve_discrete_lyapunov`` via finite differences.

    Skipped under float32 (insufficient precision) and for complex inputs
    (the ``.conj()`` gradient is not implemented).
    """
    # The block previously mixed in lines from the deleted pre-refactor test;
    # only the new parametrized version is kept.
    if config.floatX == "float32":
        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
    if use_complex:
        pytest.skip(reason="Complex numbers are not supported in the gradient test")

    rng = np.random.default_rng(utt.fetch_seed())
    A = rng.normal(size=shape).astype(config.floatX)
    Q = rng.normal(size=shape).astype(config.floatX)

    utt.verify_grad(
        functools.partial(solve_discrete_lyapunov, method=method),
        pt=[A, Q],
        rng=rng,
    )
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
    """Forward check: X solves A X + X A^H + Q = 0, so recovering Q from X
    reproduces the input Q (real and complex, with and without batch dims).
    """
    # The block previously mixed in lines from the deleted bilinear/continuous
    # tests of the pre-refactor version; only the new version is kept.
    dtype = config.floatX
    if use_complex and dtype == "float32":
        pytest.skip(
            "Not enough precision in complex64 to do schur decomposition "
            "(ill-conditioned matrix errors arise)"
        )

    rng = np.random.default_rng(utt.fetch_seed())
    if use_complex:
        precision = int(dtype[-2:])  # 64 or 32
        dtype = f"complex{int(2 * precision)}"

    A1, A2 = rng.normal(size=(2, *shape))
    Q1, Q2 = rng.normal(size=(2, *shape))

    if use_complex:
        A = A1 + 1j * A2
        Q = Q1 + 1j * Q2
    else:
        A = A1
        Q = Q1

    A, Q = A.astype(dtype), Q.astype(dtype)

    a = pt.tensor(name="a", shape=shape, dtype=dtype)
    q = pt.tensor(name="q", shape=shape, dtype=dtype)
    x = solve_continuous_lyapunov(a, q)
    f = function([a, q], x)

    X = f(A, Q)
    Q_recovered = vec_recover_Q(A, X, continuous=True)

    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
    np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
    """Gradient check for ``solve_continuous_lyapunov`` via finite differences.

    Skipped under float32 (insufficient precision) and for complex inputs
    (gradients through complex values are not supported).
    """
    # A stale assertion line from the deleted forward test was interleaved
    # here by the diff; removed.
    if config.floatX == "float32":
        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
    if use_complex:
        pytest.skip(reason="Complex numbers are not supported in the gradient test")

    rng = np.random.default_rng(utt.fetch_seed())
    A = rng.normal(size=shape).astype(config.floatX)
    Q = rng.normal(size=shape).astype(config.floatX)

    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
@pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_forward(add_batch_dim):
    """Forward check: the DARE residual of the returned X is (numerically) zero,
    optionally with a leading batch dimension on all inputs.
    """
    # TEST CASE 4 : darex #1 -- taken from Scipy tests
    # NOTE(review): the "b" fixture was hidden behind a diff-hunk boundary in
    # the scraped source; it is reconstructed from the identical fixture in
    # test_solve_discrete_are_grad — confirm against upstream.
    a, b, q, r = (
        np.array([[4, 3], [-4.5, -3.5]]),
        np.array([[1], [-1]]),
        np.array([[9, 6], [6, 4]]),
        np.array([[1]]),
    )
    if add_batch_dim:
        a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])

    a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r])
    x = solve_discrete_are(a, b, q, r)

    def eval_fun(a, b, q, r, x):
        # DARE residual: a^T x a - x - a^T x b (r + b^T x b)^{-1} b^T x a + q
        term_1 = a.T @ x @ a
        term_2 = a.T @ x @ b
        term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a
        return term_1 - x - term_2 @ term_3 + q

    # Vectorize the residual so the batched case is checked per batch member.
    res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x)
    res_np = res.eval()

    atol = 1e-4 if config.floatX == "float32" else 1e-12
    np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol)
def test_solve_discrete_are_grad():
@pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_grad(add_batch_dim):
a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]),
np.array([[1], [-1]]),
np.array([[9, 6], [6, 4]]),
np.array([[1]]),
)
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
if add_batch_dim:
a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
rng = np.random.default_rng(utt.fetch_seed())
# TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论