Unverified 提交 fffb84c1 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Cleanup for Optimal Control Ops (#1045)

* Blockwise optimal linear control ops
* Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov`
* set `solve_discrete_lyapunov` method default to bilinear
* Appease mypy
* restore method dispatching
* Use `pt.vectorize` on base `solve_discrete_lyapunov` case
* Apply JAX rewrite before canonicalization
* Improve tests
* Remove useless warning filters
* Fix local_blockwise_alloc rewrite
  The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left.
* Fix float32 tests
* Test against complex inputs
* Appease ViPy (Vieira-py type checking)
* Remove condition from `TensorLike` import
* Infer dtype from `node.outputs.type.dtype`
* Remove unused mypy ignore
* Don't manually set dtype of output
  Revert change to `_solve_discrete_lyapunov`
* Set dtype of Op outputs

---------

Co-authored-by: ricardoV94 <ricardo.vieira1994@gmail.com>
上级 dae731d1
...@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node): ...@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
value, *shape = inp.owner.inputs value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc # Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim) missing_ndim = inp.type.ndim - value.type.ndim
missing_ndim = len(shape) - value.type.ndim squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
if ( if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]) (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:] != inp.type.broadcastable[batch_ndim:]
......
...@@ -4,9 +4,11 @@ from typing import cast ...@@ -4,9 +4,11 @@ from typing import cast
from pytensor import Variable from pytensor import Variable
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
copy_stack_trace, copy_stack_trace,
in2out,
node_rewriter, node_rewriter,
) )
from pytensor.scalar.basic import Mul from pytensor.scalar.basic import Mul
...@@ -45,9 +47,11 @@ from pytensor.tensor.slinalg import ( ...@@ -45,9 +47,11 @@ from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
Solve, Solve,
SolveBase, SolveBase,
_bilinear_solve_discrete_lyapunov,
block_diag, block_diag,
cholesky, cholesky,
solve, solve,
solve_discrete_lyapunov,
solve_triangular, solve_triangular,
) )
...@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): ...@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_input, -2) non_eye_input = pt.shape_padaxis(non_eye_input, -2)
return [eye_input * (non_eye_input**0.5)] return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
    """
    Swap a bilinear discrete Lyapunov solve for the "direct" formulation.

    The two inputs of the matched node are forwarded to
    ``solve_discrete_lyapunov`` with ``method="direct"``, a computation that
    is supported by JAX. The resulting variable is returned as the single
    replacement output.
    """
    # The matched node is expected to carry exactly two inputs; unpacking
    # fails loudly otherwise.
    lhs, rhs = node.inputs
    direct_solution = solve_discrete_lyapunov(
        cast(TensorVariable, lhs),
        cast(TensorVariable, rhs),
        method="direct",
    )
    return [direct_solution]
# Register the rewrite above in the global optimization database under the
# "jax" tag. It is wrapped with `in2out` and placed at position 0.9 so that
# it is applied before canonicalization.
optdb.register(
    "jax_bilinaer_lyapunov_to_direct",
    in2out(jax_bilinaer_lyapunov_to_direct),
    "jax",
    position=0.9,  # Run before canonicalization
)
差异被折叠。
from functools import partial
from typing import Literal
import numpy as np import numpy as np
import pytest import pytest
...@@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower): ...@@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower):
None, None,
], ],
) )
@pytest.mark.parametrize("method", ["direct", "bilinear"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
def test_jax_solve_discrete_lyapunov(
    method: Literal["direct", "bilinear"], shape: tuple[int, ...]
):
    """
    Compare the JAX and Python results of ``solve_discrete_lyapunov``.

    Parameters
    ----------
    method
        Solver method forwarded to ``solve_discrete_lyapunov``.
    shape
        Static shape used for both symbolic inputs and for the random test
        values; the parametrization covers a single matrix and a batched
        stack of matrices.
    """
    A = pt.tensor(name="A", shape=shape)
    B = pt.tensor(name="B", shape=shape)
    out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
    out_fg = FunctionGraph([A, B], [out])

    # Use a looser tolerance whenever floatX is not float64.
    atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
    compare_jax_and_py(
        out_fg,
        [
            np.random.normal(size=shape).astype(config.floatX),
            np.random.normal(size=shape).astype(config.floatX),
        ],
        jax_mode="JAX",
        assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
    )
import functools import functools
import itertools import itertools
from typing import Literal
import numpy as np import numpy as np
import pytest import pytest
...@@ -514,75 +515,133 @@ def test_expm_grad_3(): ...@@ -514,75 +515,133 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng) utt.verify_grad(expm, [A], rng=rng)
def test_solve_discrete_lyapunov_via_direct_real(): def recover_Q(A, X, continuous=True):
N = 5 if continuous:
return A @ X + X @ A.conj().T
else:
return X - A @ X @ A.conj().T
vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
@pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov(
use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix("a") dtype = config.floatX
q = pt.dmatrix("q") if use_complex:
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")]) precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
if use_complex:
A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
A = rng.normal(size=(N, N)) a = pt.tensor(name="a", shape=shape, dtype=dtype)
Q = rng.normal(size=(N, N)) q = pt.tensor(name="q", shape=shape, dtype=dtype)
x = solve_discrete_lyapunov(a, q, method=method)
f = function([a, q], x)
X = f(A, Q) X = f(A, Q)
assert np.allclose(A @ X @ A.T - X + Q, 0.0) Q_recovered = vec_recover_Q(A, X, continuous=False)
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol)
@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_discrete_lyapunov_via_direct_complex(): @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
# Conj doesn't have C-op; filter the warning. @pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov_gradient(
use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
N = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
a = pt.zmatrix() A = rng.normal(size=shape).astype(config.floatX)
q = pt.zmatrix() Q = rng.normal(size=shape).astype(config.floatX)
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
Q = rng.normal(size=(N, N))
X = f(A, Q)
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
# TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented. utt.verify_grad(
# utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) functools.partial(solve_discrete_lyapunov, method=method),
pt=[A, Q],
rng=rng,
)
def test_solve_discrete_lyapunov_via_bilinear(): @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
N = 5 @pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
dtype = config.floatX
if use_complex and dtype == "float32":
pytest.skip(
"Not enough precision in complex64 to do schur decomposition "
"(ill-conditioned matrix errors arise)"
)
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix()
q = pt.dmatrix()
f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")])
A = rng.normal(size=(N, N)) if use_complex:
Q = rng.normal(size=(N, N)) precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
X = f(A, Q) A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12) if use_complex:
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng) A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
def test_solve_continuous_lyapunov(): a = pt.tensor(name="a", shape=shape, dtype=dtype)
N = 5 q = pt.tensor(name="q", shape=shape, dtype=dtype)
rng = np.random.default_rng(utt.fetch_seed()) x = solve_continuous_lyapunov(a, q)
a = pt.dmatrix()
q = pt.dmatrix() f = function([a, q], x)
f = function([a, q], [solve_continuous_lyapunov(a, q)])
A = rng.normal(size=(N, N))
Q = rng.normal(size=(N, N))
X = f(A, Q) X = f(A, Q)
Q_recovered = A @ X + X @ A.conj().T Q_recovered = vec_recover_Q(A, X, continuous=True)
atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
rng = np.random.default_rng(utt.fetch_seed())
A = rng.normal(size=shape).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX)
np.testing.assert_allclose(Q_recovered.squeeze(), Q)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng) utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
def test_solve_discrete_are_forward(): @pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_forward(add_batch_dim):
# TEST CASE 4 : darex #1 -- taken from Scipy tests # TEST CASE 4 : darex #1 -- taken from Scipy tests
a, b, q, r = ( a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]), np.array([[4, 3], [-4.5, -3.5]]),
...@@ -590,29 +649,39 @@ def test_solve_discrete_are_forward(): ...@@ -590,29 +649,39 @@ def test_solve_discrete_are_forward():
np.array([[9, 6], [6, 4]]), np.array([[9, 6], [6, 4]]),
np.array([[1]]), np.array([[1]]),
) )
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r]) if add_batch_dim:
a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
x = solve_discrete_are(a, b, q, r).eval() a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r])
res = a.T.dot(x.dot(a)) - x + q
res -= ( x = solve_discrete_are(a, b, q, r)
a.conj()
.T.dot(x.dot(b)) def eval_fun(a, b, q, r, x):
.dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a))) term_1 = a.T @ x @ a
) term_2 = a.T @ x @ b
term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a
return term_1 - x - term_2 @ term_3 + q
res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x)
res_np = res.eval()
atol = 1e-4 if config.floatX == "float32" else 1e-12 atol = 1e-4 if config.floatX == "float32" else 1e-12
np.testing.assert_allclose(res, np.zeros_like(res), atol=atol) np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol)
def test_solve_discrete_are_grad(): @pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_grad(add_batch_dim):
a, b, q, r = ( a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]), np.array([[4, 3], [-4.5, -3.5]]),
np.array([[1], [-1]]), np.array([[1], [-1]]),
np.array([[9, 6], [6, 4]]), np.array([[9, 6], [6, 4]]),
np.array([[1]]), np.array([[1]]),
) )
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r]) if add_batch_dim:
a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
# TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论