Unverified 提交 fffb84c1 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Cleanup for Optimal Control Ops (#1045)

* Blockwise optimal linear control ops * Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov` * set `solve_discrete_lyapunov` method default to bilinear * Appease mypy * restore method dispatching * Use `pt.vectorize` on base `solve_discrete_lyapunov` case * Apply JAX rewrite before canonicalization * Improve tests * Remove useless warning filters * Fix local_blockwise_alloc rewrite The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left. * Fix float32 tests * Test against complex inputs * Appease ViPy (Vieira-py type checking) * Remove condition from `TensorLike` import * Infer dtype from `node.outputs.type.dtype` * Remove unused mypy ignore * Don't manually set dtype of output Revert change to `_solve_discrete_lyapunov` * Set dtype of Op outputs --------- Co-authored-by: 's avatarricardoV94 <ricardo.vieira1994@gmail.com>
上级 dae731d1
......@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim)
missing_ndim = len(shape) - value.type.ndim
missing_ndim = inp.type.ndim - value.type.ndim
squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
......
......@@ -4,9 +4,11 @@ from typing import cast
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
......@@ -45,9 +47,11 @@ from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
solve_discrete_lyapunov,
solve_triangular,
)
......@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
return [eye_input * (non_eye_input**0.5)]
@node_rewriter([_bilinear_solve_discrete_lyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
"""
A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct")
return [result]
optdb.register(
"jax_bilinaer_lyapunov_to_direct",
in2out(jax_bilinaer_lyapunov_to_direct),
"jax",
position=0.9, # Run before canonicalization
)
......@@ -2,7 +2,7 @@ import logging
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal, cast
from typing import Literal, cast
import numpy as np
import scipy.linalg
......@@ -11,7 +11,7 @@ import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
......@@ -21,9 +21,6 @@ from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
logger = logging.getLogger(__name__)
......@@ -777,7 +774,16 @@ expm = Expm()
class SolveContinuousLyapunov(Op):
"""
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X.
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov
"""
__props__ = ()
gufunc_signature = "(m,m),(m,m)->(m,m)"
def make_node(self, A, B):
A = as_tensor_variable(A)
......@@ -792,7 +798,8 @@ class SolveContinuousLyapunov(Op):
(A, B) = inputs
X = output_storage[0]
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -813,7 +820,41 @@ class SolveContinuousLyapunov(Op):
return [A_bar, Q_bar]
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
"""
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
docstring for scipy.linalg.solve_discrete_lyapunov
"""
gufunc_signature = "(m,m),(m,m)->(m,m)"
def make_node(self, A, B):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
......@@ -827,7 +868,10 @@ class BilinearSolveDiscreteLyapunov(Op):
(A, B) = inputs
X = output_storage[0]
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -849,46 +893,56 @@ class BilinearSolveDiscreteLyapunov(Op):
return [A_bar, Q_bar]
_solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
A_ = as_tensor_variable(A)
Q_ = as_tensor_variable(Q)
def _direct_solve_discrete_lyapunov(
A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
r"""
Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the kronecker method of Magnus and
Neudecker.
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
"""
if "complex" in A_.type.dtype:
AA = kron(A_, A_.conj())
if A.type.dtype.startswith("complex"):
AxA = kron(A, A.conj())
else:
AA = kron(A_, A_)
AxA = kron(A, A)
eye = pt.eye(AxA.shape[-1])
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
return cast(TensorVariable, reshape(X, Q_.shape))
vec_Q = Q.ravel()
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
return cast(TensorVariable, reshape(vec_X, A.shape))
def solve_discrete_lyapunov(
A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
A: TensorLike,
Q: TensorLike,
method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
"""Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
Parameters
----------
A
Square matrix of shape N x N; must have the same shape as Q
Q
Square matrix of shape N x N; must have the same shape as A
method
Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
solves the problem directly via matrix inversion. This has a pure
PyTensor implementation and can thus be cross-compiled to supported
backends, and should be preferred when ``N`` is not large. The direct
method scales poorly with the size of ``N``, and the bilinear can be
A: TensorLike
Square matrix of shape N x N
Q: TensorLike
Square matrix of shape N x N
method: str, one of ``"direct"`` or ``"bilinear"``
Solver method used, . ``"direct"`` solves the problem directly via matrix inversion. This has a pure
PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
used in these cases.
Returns
-------
Square matrix of shape ``N x N``, representing the solution to the
Lyapunov equation
X: TensorVariable
Square matrix of shape ``N x N``. Solution to the Lyapunov equation
"""
if method not in ["direct", "bilinear"]:
......@@ -896,36 +950,26 @@ def solve_discrete_lyapunov(
f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
)
if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear":
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
"""Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A
Square matrix of shape ``N x N``; must have the same shape as `Q`.
Q
Square matrix of shape ``N x N``; must have the same shape as `A`.
A = as_tensor_variable(A)
Q = as_tensor_variable(Q)
Returns
-------
Square matrix of shape ``N x N``, representing the solution to the
Lyapunov equation
if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
return cast(TensorVariable, X)
"""
elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
else:
raise ValueError(f"Unknown method {method}")
class SolveDiscreteARE(pt.Op):
class SolveDiscreteARE(Op):
__props__ = ("enforce_Q_symmetric",)
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
def __init__(self, enforce_Q_symmetric=False):
def __init__(self, enforce_Q_symmetric: bool = False):
self.enforce_Q_symmetric = enforce_Q_symmetric
def make_node(self, A, B, Q, R):
......@@ -946,9 +990,8 @@ class SolveDiscreteARE(pt.Op):
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
node.outputs[0].type.dtype
)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -960,14 +1003,16 @@ class SolveDiscreteARE(pt.Op):
(dX,) = output_grads
X = self(A, B, Q, R)
K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
K = matrix_dot(K_inner_inv, B.T, X, A)
K_inner = R + matrix_dot(B.T, X, B)
# K_inner is guaranteed to be symmetric, because X and R are symmetric
K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
K = matrix_dot(K_inner_inv_BT, X, A)
A_tilde = A - B.dot(K)
dX_symm = 0.5 * (dX + dX.T)
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
S = solve_discrete_lyapunov(A_tilde, dX_symm)
A_bar = 2 * matrix_dot(X, A_tilde, S)
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
......@@ -977,30 +1022,45 @@ class SolveDiscreteARE(pt.Op):
return [A_bar, B_bar, Q_bar, R_bar]
def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
def solve_discrete_are(
A: TensorLike,
B: TensorLike,
Q: TensorLike,
R: TensorLike,
enforce_Q_symmetric: bool = False,
) -> TensorVariable:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Guassian (LQG) control problems, and as the
steady-state covariance of the Kalman Filter.
Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
solution. This stable solution, if it exists, will be returned by this function.
Parameters
----------
A: ArrayLike
A: TensorLike
Square matrix of shape M x M
B: ArrayLike
B: TensorLike
Square matrix of shape M x M
Q: ArrayLike
Q: TensorLike
Symmetric square matrix of shape M x M
R: ArrayLike
R: TensorLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Returns
-------
X: pt.matrix
X: TensorVariable
Square matrix of shape M x M, representing the solution to the DARE
"""
return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
return cast(
TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
)
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
......
from functools import partial
from typing import Literal
import numpy as np
import pytest
......@@ -194,3 +197,25 @@ def test_jax_eigvalsh(lower):
None,
],
)
@pytest.mark.parametrize("method", ["direct", "bilinear"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
def test_jax_solve_discrete_lyapunov(
method: Literal["direct", "bilinear"], shape: tuple[int]
):
A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
out_fg = FunctionGraph([A, B], [out])
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py(
out_fg,
[
np.random.normal(size=shape).astype(config.floatX),
np.random.normal(size=shape).astype(config.floatX),
],
jax_mode="JAX",
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
)
import functools
import itertools
from typing import Literal
import numpy as np
import pytest
......@@ -514,75 +515,133 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng)
def test_solve_discrete_lyapunov_via_direct_real():
N = 5
def recover_Q(A, X, continuous=True):
if continuous:
return A @ X + X @ A.conj().T
else:
return X - A @ X @ A.conj().T
vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
@pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov(
use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix("a")
q = pt.dmatrix("q")
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
dtype = config.floatX
if use_complex:
precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
if use_complex:
A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
a = pt.tensor(name="a", shape=shape, dtype=dtype)
q = pt.tensor(name="q", shape=shape, dtype=dtype)
A = rng.normal(size=(N, N))
Q = rng.normal(size=(N, N))
x = solve_discrete_lyapunov(a, q, method=method)
f = function([a, q], x)
X = f(A, Q)
assert np.allclose(A @ X @ A.T - X + Q, 0.0)
Q_recovered = vec_recover_Q(A, X, continuous=False)
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_solve_discrete_lyapunov_via_direct_complex():
# Conj doesn't have C-op; filter the warning.
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
@pytest.mark.parametrize("method", ["direct", "bilinear"])
def test_solve_discrete_lyapunov_gradient(
use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
N = 5
rng = np.random.default_rng(utt.fetch_seed())
a = pt.zmatrix()
q = pt.zmatrix()
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
A = rng.normal(size=shape).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX)
A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
Q = rng.normal(size=(N, N))
X = f(A, Q)
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
# TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
# utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
utt.verify_grad(
functools.partial(solve_discrete_lyapunov, method=method),
pt=[A, Q],
rng=rng,
)
def test_solve_discrete_lyapunov_via_bilinear():
N = 5
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
dtype = config.floatX
if use_complex and dtype == "float32":
pytest.skip(
"Not enough precision in complex64 to do schur decomposition "
"(ill-conditioned matrix errors arise)"
)
rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix()
q = pt.dmatrix()
f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")])
A = rng.normal(size=(N, N))
Q = rng.normal(size=(N, N))
if use_complex:
precision = int(dtype[-2:]) # 64 or 32
dtype = f"complex{int(2 * precision)}"
X = f(A, Q)
A1, A2 = rng.normal(size=(2, *shape))
Q1, Q2 = rng.normal(size=(2, *shape))
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
if use_complex:
A = A1 + 1j * A2
Q = Q1 + 1j * Q2
else:
A = A1
Q = Q1
A, Q = A.astype(dtype), Q.astype(dtype)
def test_solve_continuous_lyapunov():
N = 5
rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix()
q = pt.dmatrix()
f = function([a, q], [solve_continuous_lyapunov(a, q)])
a = pt.tensor(name="a", shape=shape, dtype=dtype)
q = pt.tensor(name="q", shape=shape, dtype=dtype)
x = solve_continuous_lyapunov(a, q)
f = function([a, q], x)
A = rng.normal(size=(N, N))
Q = rng.normal(size=(N, N))
X = f(A, Q)
Q_recovered = A @ X + X @ A.conj().T
Q_recovered = vec_recover_Q(A, X, continuous=True)
atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
if use_complex:
pytest.skip(reason="Complex numbers are not supported in the gradient test")
rng = np.random.default_rng(utt.fetch_seed())
A = rng.normal(size=shape).astype(config.floatX)
Q = rng.normal(size=shape).astype(config.floatX)
np.testing.assert_allclose(Q_recovered.squeeze(), Q)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
def test_solve_discrete_are_forward():
@pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_forward(add_batch_dim):
# TEST CASE 4 : darex #1 -- taken from Scipy tests
a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]),
......@@ -590,29 +649,39 @@ def test_solve_discrete_are_forward():
np.array([[9, 6], [6, 4]]),
np.array([[1]]),
)
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
if add_batch_dim:
a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
x = solve_discrete_are(a, b, q, r).eval()
res = a.T.dot(x.dot(a)) - x + q
res -= (
a.conj()
.T.dot(x.dot(b))
.dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
)
a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r])
x = solve_discrete_are(a, b, q, r)
def eval_fun(a, b, q, r, x):
term_1 = a.T @ x @ a
term_2 = a.T @ x @ b
term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a
return term_1 - x - term_2 @ term_3 + q
res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x)
res_np = res.eval()
atol = 1e-4 if config.floatX == "float32" else 1e-12
np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)
np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol)
def test_solve_discrete_are_grad():
@pytest.mark.parametrize("add_batch_dim", [False, True])
def test_solve_discrete_are_grad(add_batch_dim):
a, b, q, r = (
np.array([[4, 3], [-4.5, -3.5]]),
np.array([[1], [-1]]),
np.array([[9, 6], [6, 4]]),
np.array([[1]]),
)
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
if add_batch_dim:
a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
rng = np.random.default_rng(utt.fetch_seed())
# TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论