提交 3acb9b4f authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Move linear control Ops to `linear_control.py`

上级 d9889cca
...@@ -3,6 +3,7 @@ import warnings ...@@ -3,6 +3,7 @@ import warnings
import jax import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor._linalg.solve.linear_control import SolveSylvester
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR, QR,
...@@ -15,7 +16,6 @@ from pytensor.tensor.slinalg import ( ...@@ -15,7 +16,6 @@ from pytensor.tensor.slinalg import (
PivotToPermutations, PivotToPermutations,
Schur, Schur,
Solve, Solve,
SolveSylvester,
SolveTriangular, SolveTriangular,
) )
......
...@@ -42,10 +42,10 @@ from pytensor.link.numba.dispatch.string_codegen import ( ...@@ -42,10 +42,10 @@ from pytensor.link.numba.dispatch.string_codegen import (
CODE_TOKEN, CODE_TOKEN,
build_source_code, build_source_code,
) )
from pytensor.tensor._linalg.solve.linear_control import TRSYL
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR, QR,
TRSYL,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
......
from pytensor.tensor._linalg.solve.linear_control import *
from pytensor.tensor.nlinalg import * from pytensor.tensor.nlinalg import *
from pytensor.tensor.slinalg import * from pytensor.tensor.slinalg import *
差异被折叠。
...@@ -10,6 +10,7 @@ from pytensor.configdefaults import config ...@@ -10,6 +10,7 @@ from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor._linalg.solve import linear_control
from pytensor.tensor.math import clip, cosh from pytensor.tensor.math import clip, cosh
from pytensor.tensor.type import matrix, vector from pytensor.tensor.type import matrix, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -275,7 +276,7 @@ def test_jax_solve_discrete_lyapunov( ...@@ -275,7 +276,7 @@ def test_jax_solve_discrete_lyapunov(
): ):
A = pt.tensor(name="A", shape=shape) A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape) B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method) out = linear_control.solve_discrete_lyapunov(A, B, method=method)
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3 atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py( compare_jax_and_py(
...@@ -404,6 +405,6 @@ def test_jax_solve_sylvester(): ...@@ -404,6 +405,6 @@ def test_jax_solve_sylvester():
B_val = rng.normal(size=(3, 3)).astype(config.floatX) B_val = rng.normal(size=(3, 3)).astype(config.floatX)
C_val = rng.normal(size=(3, 3)).astype(config.floatX) C_val = rng.normal(size=(3, 3)).astype(config.floatX)
out = pt_slinalg.solve_sylvester(A, B, C) out = linear_control.solve_sylvester(A, B, C)
compare_jax_and_py([A, B, C], [out], [A_val, B_val, C_val]) compare_jax_and_py([A, B, C], [out], [A_val, B_val, C_val])
...@@ -3,6 +3,7 @@ import pytest ...@@ -3,6 +3,7 @@ import pytest
from pytensor import config from pytensor import config
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.tensor._linalg.solve import linear_control
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -17,7 +18,7 @@ def test_solve_sylvester(): ...@@ -17,7 +18,7 @@ def test_solve_sylvester():
A = pt.matrix("A") A = pt.matrix("A")
B = pt.matrix("B") B = pt.matrix("B")
C = pt.matrix("C") C = pt.matrix("C")
X = pt.linalg.solve_sylvester(A, B, C) X = linear_control.solve_sylvester(A, B, C)
rng = np.random.default_rng() rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX) A_val = rng.normal(size=(5, 5)).astype(floatX)
...@@ -30,7 +31,7 @@ def test_solve_sylvester(): ...@@ -30,7 +31,7 @@ def test_solve_sylvester():
def test_solve_continuous_lyapunov(): def test_solve_continuous_lyapunov():
A = pt.matrix("A") A = pt.matrix("A")
Q = pt.matrix("Q") Q = pt.matrix("Q")
X = pt.linalg.solve_continuous_lyapunov(A, Q) X = linear_control.solve_continuous_lyapunov(A, Q)
rng = np.random.default_rng() rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX) A_val = rng.normal(size=(5, 5)).astype(floatX)
...@@ -44,7 +45,7 @@ def test_solve_continuous_lyapunov(): ...@@ -44,7 +45,7 @@ def test_solve_continuous_lyapunov():
def test_solve_discrete_lyapunov(method): def test_solve_discrete_lyapunov(method):
A = pt.matrix("A") A = pt.matrix("A")
Q = pt.matrix("Q") Q = pt.matrix("Q")
X = pt.linalg.solve_discrete_lyapunov(A, Q, method=method) X = linear_control.solve_discrete_lyapunov(A, Q, method=method)
rng = np.random.default_rng() rng = np.random.default_rng()
A_val = rng.normal(size=(5, 5)).astype(floatX) A_val = rng.normal(size=(5, 5)).astype(floatX)
......
...@@ -15,6 +15,12 @@ from pytensor.configdefaults import config ...@@ -15,6 +15,12 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
from pytensor.tensor import TensorVariable from pytensor.tensor import TensorVariable
from pytensor.tensor._linalg.solve.linear_control import (
solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov,
solve_sylvester,
)
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -33,10 +39,6 @@ from pytensor.tensor.slinalg import ( ...@@ -33,10 +39,6 @@ from pytensor.tensor.slinalg import (
qr, qr,
schur, schur,
solve, solve,
solve_continuous_lyapunov,
solve_discrete_are,
solve_discrete_lyapunov,
solve_sylvester,
solve_triangular, solve_triangular,
) )
from pytensor.tensor.type import dmatrix, matrix, tensor, vector from pytensor.tensor.type import dmatrix, matrix, tensor, vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论