提交 3acb9b4f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Move linear control Ops to `linear_control.py`

上级 d9889cca
......@@ -3,6 +3,7 @@ import warnings
import jax
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor._linalg.solve.linear_control import SolveSylvester
from pytensor.tensor.slinalg import (
LU,
QR,
......@@ -15,7 +16,6 @@ from pytensor.tensor.slinalg import (
PivotToPermutations,
Schur,
Solve,
SolveSylvester,
SolveTriangular,
)
......
......@@ -42,10 +42,10 @@ from pytensor.link.numba.dispatch.string_codegen import (
CODE_TOKEN,
build_source_code,
)
from pytensor.tensor._linalg.solve.linear_control import TRSYL
from pytensor.tensor.slinalg import (
LU,
QR,
TRSYL,
BlockDiagonal,
Cholesky,
CholeskySolve,
......
from pytensor.tensor._linalg.solve.linear_control import *
from pytensor.tensor.nlinalg import *
from pytensor.tensor.slinalg import *
差异被折叠。
......@@ -10,6 +10,7 @@ from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor._linalg.solve import linear_control
from pytensor.tensor.math import clip, cosh
from pytensor.tensor.type import matrix, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -275,7 +276,7 @@ def test_jax_solve_discrete_lyapunov(
):
A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
out = linear_control.solve_discrete_lyapunov(A, B, method=method)
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py(
......@@ -404,6 +405,6 @@ def test_jax_solve_sylvester():
B_val = rng.normal(size=(3, 3)).astype(config.floatX)
C_val = rng.normal(size=(3, 3)).astype(config.floatX)
out = pt_slinalg.solve_sylvester(A, B, C)
out = linear_control.solve_sylvester(A, B, C)
compare_jax_and_py([A, B, C], [out], [A_val, B_val, C_val])
......@@ -3,6 +3,7 @@ import pytest
from pytensor import config
from pytensor import tensor as pt
from pytensor.tensor._linalg.solve import linear_control
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -17,7 +18,7 @@ def test_solve_sylvester():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
X = pt.linalg.solve_sylvester(A, B, C)
X = linear_control.solve_sylvester(A, B, C)
rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX)
......@@ -30,7 +31,7 @@ def test_solve_sylvester():
def test_solve_continuous_lyapunov():
A = pt.matrix("A")
Q = pt.matrix("Q")
X = pt.linalg.solve_continuous_lyapunov(A, Q)
X = linear_control.solve_continuous_lyapunov(A, Q)
rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX)
......@@ -44,7 +45,7 @@ def test_solve_continuous_lyapunov():
def test_solve_discrete_lyapunov(method):
A = pt.matrix("A")
Q = pt.matrix("Q")
X = pt.linalg.solve_discrete_lyapunov(A, Q, method=method)
X = linear_control.solve_discrete_lyapunov(A, Q, method=method)
rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX)
......
......@@ -15,6 +15,12 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.link.numba import NumbaLinker
from pytensor.tensor import TensorVariable
from pytensor.tensor._linalg.solve.linear_control import (
solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov,
solve_sylvester,
)
from pytensor.tensor.slinalg import (
Cholesky,
CholeskySolve,
......@@ -33,10 +39,6 @@ from pytensor.tensor.slinalg import (
qr,
schur,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov,
solve_sylvester,
solve_triangular,
)
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论