Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3acb9b4f
提交
3acb9b4f
authored
1月 13, 2026
作者:
jessegrabowski
提交者:
Jesse Grabowski
1月 18, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move linear control Ops to `linear_control.py`
上级
d9889cca
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
404 行增加
和
363 行删除
+404
-363
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+1
-1
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+1
-1
linear_control.py
pytensor/tensor/_linalg/solve/linear_control.py
+366
-0
linalg.py
pytensor/tensor/linalg.py
+1
-0
slinalg.py
pytensor/tensor/slinalg.py
+22
-352
test_slinalg.py
tests/link/jax/test_slinalg.py
+3
-2
test_linear_control.py
tests/link/numba/linalg/solve/test_linear_control.py
+4
-3
test_slinalg.py
tests/tensor/test_slinalg.py
+6
-4
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
3acb9b4f
...
@@ -3,6 +3,7 @@ import warnings
...
@@ -3,6 +3,7 @@ import warnings
import
jax
import
jax
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor._linalg.solve.linear_control
import
SolveSylvester
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
LU
,
LU
,
QR
,
QR
,
...
@@ -15,7 +16,6 @@ from pytensor.tensor.slinalg import (
...
@@ -15,7 +16,6 @@ from pytensor.tensor.slinalg import (
PivotToPermutations
,
PivotToPermutations
,
Schur
,
Schur
,
Solve
,
Solve
,
SolveSylvester
,
SolveTriangular
,
SolveTriangular
,
)
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
3acb9b4f
...
@@ -42,10 +42,10 @@ from pytensor.link.numba.dispatch.string_codegen import (
...
@@ -42,10 +42,10 @@ from pytensor.link.numba.dispatch.string_codegen import (
CODE_TOKEN
,
CODE_TOKEN
,
build_source_code
,
build_source_code
,
)
)
from
pytensor.tensor._linalg.solve.linear_control
import
TRSYL
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
LU
,
LU
,
QR
,
QR
,
TRSYL
,
BlockDiagonal
,
BlockDiagonal
,
Cholesky
,
Cholesky
,
CholeskySolve
,
CholeskySolve
,
...
...
pytensor/tensor/_linalg/solve/linear_control.py
0 → 100644
浏览文件 @
3acb9b4f
from
typing
import
Literal
,
cast
import
numpy
as
np
from
scipy
import
linalg
as
scipy_linalg
from
scipy.linalg
import
get_lapack_funcs
import
pytensor
import
pytensor.tensor.basic
as
ptb
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.graph
import
Apply
,
Op
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.functional
import
vectorize
from
pytensor.tensor.nlinalg
import
kron
,
matrix_dot
from
pytensor.tensor.reshape
import
join_dims
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.slinalg
import
schur
,
solve
from
pytensor.tensor.type
import
matrix
from
pytensor.tensor.variable
import
TensorVariable
class TRSYL(Op):
    """
    Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation:

        op(A) @ X + X @ op(B) = alpha * C

    Where `op(A)` is either `A` or `A^T`, depending on the options passed to `trsyl`. A and B must be
    in Schur canonical form: block upper triangular matrices with 1x1 and 2x2 blocks on the diagonal;
    each 2x2 diagonal block has its diagonal elements equal and its off-diagonal elements opposite in sign.

    This Op is not public facing. Instead, it is intended to be used as a building block for higher-level
    linear control solvers, such as `SolveSylvester` and `SolveContinuousLyapunov`.
    """

    __props__ = ("overwrite_c",)
    gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"

    def __init__(self, overwrite_c: bool = False):
        self.overwrite_c = overwrite_c
        if self.overwrite_c:
            # C (input index 2) is destroyed in-place when overwrite_c is set.
            self.destroy_map = {0: [2]}

    def make_node(self, A, B, C):
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        C = as_tensor_variable(C)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, C.dtype)

        # X has the shape of C, but missing static dims can be recovered from
        # the (square) A and B factors: X is (m, n) with A (m, m) and B (n, n).
        output_shape = list(C.type.shape)
        if output_shape[0] is None and A.type.shape[0] is not None:
            output_shape[0] = A.type.shape[0]
        if output_shape[1] is None and B.type.shape[0] is not None:
            output_shape[1] = B.type.shape[0]

        X = ptb.tensor(dtype=out_dtype, shape=tuple(output_shape))
        return Apply(self, [A, B, C], [X])

    def perform(self, node, inputs, outputs_storage):
        (A, B, C) = inputs
        X = outputs_storage[0]
        out_dtype = node.outputs[0].type.dtype

        (trsyl,) = get_lapack_funcs(("trsyl",), (A, B, C))

        if A.size == 0 or B.size == 0:
            # Degenerate (empty) problem: the solution is an empty array.
            # NOTE: `perform` communicates results via the output storage, so we
            # must assign to X[0]; a bare `return <value>` would be discarded
            # and leave the output unset.
            X[0] = np.empty_like(C, dtype=out_dtype)
            return

        Y, scale, info = trsyl(A, B, C, overwrite_c=self.overwrite_c)
        if info < 0:
            # Negative info signals an illegal argument to trsyl; propagate
            # failure as NaNs (assigned to storage, not returned — see above).
            X[0] = np.full_like(C, np.nan, dtype=out_dtype)
            return

        # trsyl returns the solution up to a scale factor chosen to avoid overflow.
        Y *= scale
        X[0] = Y

    def infer_shape(self, fgraph, node, shapes):
        # Output has the same shape as C.
        return [shapes[2]]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        if not allowed_inplace_inputs:
            return self
        new_props = self._props_dict()  # type: ignore
        new_props["overwrite_c"] = True
        return type(self)(**new_props)
def _lop_solve_continuous_sylvester(inputs, outputs, output_grads):
    """
    Closed-form gradients for the solution of the Sylvester equation.

    Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf

    Note that these authors write the equation as AP + PB + C = 0. The code here follows scipy notation,
    so P = X and C = -Q. This change of notation requires minor adjustment to equations 10 and 11c.
    """
    a, b, _ = inputs
    [x] = outputs
    [x_grad] = output_grads

    # Eq 10: adjoint Sylvester solve (sign of the rhs flipped by the change of notation)
    s = solve_sylvester(a.conj().mT, b.conj().mT, -x_grad)

    grad_a = s @ x.conj().mT  # Eq 11a
    grad_b = x.conj().mT @ s  # Eq 11b
    grad_q = -s  # Eq 11c

    return [grad_a, grad_b, grad_q]
class SolveSylvester(OpFromGraph):
    """
    OpFromGraph wrapper around the graph that solves the continuous Sylvester
    equation :math:`AX + XB = C` for :math:`X`.
    """

    gufunc_signature = "(m,m),(n,n),(m,n)->(m,n)"
def solve_sylvester(A: TensorLike, B: TensorLike, Q: TensorLike) -> TensorVariable:
    """
    Solve the Sylvester equation :math:`AX + XB = Q` for :math:`X`.

    Following scipy notation, this function solves the continuous-time Sylvester equation.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape ``M x M``.
    B: TensorLike
        Square matrix of shape ``N x N``.
    Q: TensorLike
        Matrix of shape ``M x N``.

    Returns
    -------
    X: TensorVariable
        Matrix of shape ``M x N``.
    """
    A = as_tensor_variable(A)
    B = as_tensor_variable(B)
    Q = as_tensor_variable(Q)

    # The core (2d) graph is built on fresh matrix variables so it can be
    # wrapped in an OpFromGraph and then Blockwise'd over any batch dims of
    # the actual inputs.
    A_matrix = matrix(dtype=A.dtype, shape=A.type.shape[-2:])
    B_matrix = matrix(dtype=B.dtype, shape=B.type.shape[-2:])
    Q_matrix = matrix(dtype=Q.dtype, shape=Q.type.shape[-2:])

    # Bartels-Stewart: reduce A and B to (real) Schur form, solve the resulting
    # quasi-triangular system with LAPACK's trsyl, then rotate back.
    R, U = schur(A_matrix, output="real")
    S, V = schur(B_matrix, output="real")

    F = U.conj().mT @ Q_matrix @ V
    _trsyl = Blockwise(TRSYL())
    Y = _trsyl(R, S, F)
    X = U @ Y @ V.conj().mT

    op = SolveSylvester(
        inputs=[A_matrix, B_matrix, Q_matrix],
        outputs=[X],
        # Closed-form L_op avoids differentiating through the Schur decompositions.
        lop_overrides=_lop_solve_continuous_sylvester,
    )

    return cast(TensorVariable, Blockwise(op)(A, B, Q))
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
    """
    Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.

    Note that the lyapunov equation is a special case of the Sylvester equation, with :math:`B = A^H`. This function
    thus simply calls `solve_sylvester` with the appropriate arguments.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape ``N x N``.
    Q: TensorLike
        Square matrix of shape ``N x N``.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``
    """
    A_pt = as_tensor_variable(A)
    Q_pt = as_tensor_variable(Q)

    # Special case of the Sylvester equation with B = A^H.
    return solve_sylvester(A_pt, A_pt.conj().mT, Q_pt)
class SolveBilinearDiscreteLyapunov(OpFromGraph):
    """
    OpFromGraph wrapper for the graph that solves the discrete Lyapunov
    equation :math:`A X A^H - X = Q` for :math:`X` via the bilinear transform.

    Required so that backends that do not support method='bilinear' in
    `solve_discrete_lyapunov` can be rewritten to method='direct'.
    """
def solve_discrete_lyapunov(
    A: TensorLike,
    Q: TensorLike,
    method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
    """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape N x N
    Q: TensorLike
        Square matrix of shape N x N
    method: str, one of ``"direct"`` or ``"bilinear"``
        Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
        PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
        ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
        used in these cases.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``. Solution to the Lyapunov equation

    Raises
    ------
    ValueError
        If ``method`` is not one of ``"direct"`` or ``"bilinear"``.
    """
    if method not in ("direct", "bilinear"):
        raise ValueError(
            f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
        )

    A = as_tensor_variable(A)
    Q = as_tensor_variable(Q)

    if method == "direct":
        # Vectorize the equation: (I - conj(A) ⊗ A) vec(X) = vec(Q), then solve
        # the resulting N^2 x N^2 linear system.
        vec_kron = vectorize(kron, signature="(n,n),(n,n)->(m,m)")
        AxA = vec_kron(A, A.conj())
        # `eye_n2` replaces the ambiguous name `I` (flagged by linters as E741).
        eye_n2 = ptb.eye(AxA.shape[-1])

        vec_Q = join_dims(Q, start_axis=-2, n_axes=2)
        vec_X = solve(eye_n2 - AxA, vec_Q, b_ndim=1)

        return reshape(vec_X, A.shape)

    # method == "bilinear" (validated above, so no further branches needed)
    # Map the discrete equation to a continuous Lyapunov equation via the
    # bilinear (Cayley) transform, solve it, and wrap the graph in an
    # OpFromGraph so backends can rewrite it if unsupported.
    eye_n = ptb.eye(A.shape[-2])

    B_1 = A.conj().mT + eye_n
    B_2 = A.conj().mT - eye_n
    B = solve(B_1.mT, B_2.mT).mT

    AI_inv_Q = solve(A + eye_n, Q)
    C = 2 * solve(B_1.mT, AI_inv_Q.mT).mT

    X = solve_continuous_lyapunov(B.conj().mT, -C)

    op = SolveBilinearDiscreteLyapunov(inputs=[A, Q], outputs=[X])
    return cast(TensorVariable, op(A, Q))
class SolveDiscreteARE(Op):
    """
    Op computing the stabilizing solution of the discrete Algebraic Riccati
    equation via `scipy.linalg.solve_discrete_are`, with an analytic gradient.
    """

    __props__ = ("enforce_Q_symmetric",)
    gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

    def __init__(self, enforce_Q_symmetric: bool = False):
        # When True, Q is symmetrized as 0.5 * (Q + Q.T) before solving.
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        Q = as_tensor_variable(Q)
        R = as_tensor_variable(R)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])

    def perform(self, node, inputs, output_storage):
        A, B, Q, R = inputs
        X = output_storage[0]

        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)

        out_dtype = node.outputs[0].type.dtype
        X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

    def infer_shape(self, fgraph, node, shapes):
        # Output X is square with the shape of A.
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs
        (dX,) = output_grads
        X = self(A, B, Q, R)

        # K_inner is guaranteed to be symmetric, because X and R are symmetric
        K_inner = R + matrix_dot(B.T, X, B)
        K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
        K = matrix_dot(K_inner_inv_BT, X, A)

        # Closed-loop dynamics matrix.
        A_tilde = A - B.dot(K)

        # Symmetrize the incoming gradient before solving the Lyapunov equation.
        dX_symm = 0.5 * (dX + dX.T)
        S = solve_discrete_lyapunov(A_tilde, dX_symm)

        return [
            2 * matrix_dot(X, A_tilde, S),  # A_bar
            -2 * matrix_dot(X, A_tilde, S, K.T),  # B_bar
            S,  # Q_bar
            matrix_dot(K, S, K.T),  # R_bar
        ]
def solve_discrete_are(
    A: TensorLike,
    B: TensorLike,
    Q: TensorLike,
    R: TensorLike,
    enforce_Q_symmetric: bool = False,
) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
    solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the
    steady-state covariance of the Kalman Filter.

    Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
    solution. This stable solution, if it exists, will be returned by this function.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape M x M
    B: TensorLike
        Matrix of shape M x N
    Q: TensorLike
        Symmetric square matrix of shape M x M
    R: TensorLike
        Square matrix of shape N x N
    enforce_Q_symmetric: bool
        If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

    Returns
    -------
    X: TensorVariable
        Square matrix of shape M x M, representing the solution to the DARE
    """
    # Blockwise allows arbitrary batch dimensions on all four inputs; the core
    # op has gufunc signature (m,m),(m,n),(m,m),(n,n)->(m,m).
    return cast(
        TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
    )
__all__
=
[
"solve_continuous_lyapunov"
,
"solve_discrete_are"
,
"solve_discrete_lyapunov"
,
"solve_sylvester"
,
]
pytensor/tensor/linalg.py
浏览文件 @
3acb9b4f
from
pytensor.tensor._linalg.solve.linear_control
import
*
from
pytensor.tensor.nlinalg
import
*
from
pytensor.tensor.nlinalg
import
*
from
pytensor.tensor.slinalg
import
*
from
pytensor.tensor.slinalg
import
*
pytensor/tensor/slinalg.py
浏览文件 @
3acb9b4f
...
@@ -11,7 +11,6 @@ from scipy.linalg import get_lapack_funcs
...
@@ -11,7 +11,6 @@ from scipy.linalg import get_lapack_funcs
import
pytensor
import
pytensor
from
pytensor
import
ifelse
from
pytensor
import
ifelse
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
from
pytensor.gradient
import
DisconnectedType
,
disconnected_type
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
...
@@ -21,9 +20,6 @@ from pytensor.tensor import basic as ptb
...
@@ -21,9 +20,6 @@ from pytensor.tensor import basic as ptb
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor.basic
import
as_tensor_variable
,
diagonal
from
pytensor.tensor.basic
import
as_tensor_variable
,
diagonal
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.nlinalg
import
kron
,
matrix_dot
from
pytensor.tensor.reshape
import
join_dims
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.type
import
matrix
,
tensor
,
vector
from
pytensor.tensor.type
import
matrix
,
tensor
,
vector
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -1298,350 +1294,6 @@ class Expm(Op):
...
@@ -1298,350 +1294,6 @@ class Expm(Op):
expm
=
Blockwise
(
Expm
())
expm
=
Blockwise
(
Expm
())
class
TRSYL
(
Op
):
"""
Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation:
op(A) @ X + X @ op(B) = alpha * C
Where `op(A)` is either `A` or `A^T`, depending on the options passed to `trsyl`. A and B must be
in Schur canonical form: block upper triangular matrices with 1x1 and 2x2 blocks on the diagonal;
each 2x2 diagonal block has its diagonal elements equal and its off-diagonal elements opposite in sign.
This Op is not public facing. Instead, it is intended to be used as a building block for higher-level
linear control solvers, such as `SolveSylvester` and `SolveContinuousLyapunov`.
"""
__props__
=
(
"overwrite_c"
,)
gufunc_signature
=
"(m,m),(n,n),(m,n)->(m,n)"
def
__init__
(
self
,
overwrite_c
=
False
):
self
.
overwrite_c
=
overwrite_c
if
self
.
overwrite_c
:
self
.
destroy_map
=
{
0
:
[
2
]}
def
make_node
(
self
,
A
,
B
,
C
):
A
=
as_tensor_variable
(
A
)
B
=
as_tensor_variable
(
B
)
C
=
as_tensor_variable
(
C
)
out_dtype
=
pytensor
.
scalar
.
upcast
(
A
.
dtype
,
B
.
dtype
,
C
.
dtype
)
output_shape
=
list
(
C
.
type
.
shape
)
if
output_shape
[
0
]
is
None
and
A
.
type
.
shape
[
0
]
is
not
None
:
output_shape
[
0
]
=
A
.
type
.
shape
[
0
]
if
output_shape
[
1
]
is
None
and
B
.
type
.
shape
[
0
]
is
not
None
:
output_shape
[
1
]
=
B
.
type
.
shape
[
0
]
X
=
tensor
(
dtype
=
out_dtype
,
shape
=
tuple
(
output_shape
))
return
Apply
(
self
,
[
A
,
B
,
C
],
[
X
])
def
perform
(
self
,
node
,
inputs
,
outputs_storage
):
(
A
,
B
,
C
)
=
inputs
X
=
outputs_storage
[
0
]
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
(
trsyl
,)
=
get_lapack_funcs
((
"trsyl"
,),
(
A
,
B
,
C
))
if
A
.
size
==
0
or
B
.
size
==
0
:
return
np
.
empty_like
(
C
,
dtype
=
out_dtype
)
Y
,
scale
,
info
=
trsyl
(
A
,
B
,
C
,
overwrite_c
=
self
.
overwrite_c
)
if
info
<
0
:
return
np
.
full_like
(
C
,
np
.
nan
,
dtype
=
out_dtype
)
Y
*=
scale
X
[
0
]
=
Y
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
2
]]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
return
self
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_c"
]
=
True
return
type
(
self
)(
**
new_props
)
def
_trsyl
(
A
:
TensorLike
,
B
:
TensorLike
,
C
:
TensorLike
)
->
TensorVariable
:
A
=
as_tensor_variable
(
A
)
B
=
as_tensor_variable
(
B
)
C
=
as_tensor_variable
(
C
)
return
cast
(
TensorVariable
,
Blockwise
(
TRSYL
())(
A
,
B
,
C
))
class
SolveSylvester
(
OpFromGraph
):
"""
Wrapper Op for solving the continuous Sylvester equation :math:`AX + XB = C` for :math:`X`.
"""
gufunc_signature
=
"(m,m),(n,n),(m,n)->(m,n)"
def
_lop_solve_continuous_sylvester
(
inputs
,
outputs
,
output_grads
):
"""
Closed-form gradients for the solution for the Sylvester equation.
Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
Note that these authors write the equation as AP + PB + C = 0. The code here follows scipy notation,
so P = X and C = -Q. This change of notation requires minor adjustment to equations 10 and 11c
"""
A
,
B
,
_
=
inputs
(
dX
,)
=
output_grads
(
X
,)
=
outputs
S
=
solve_sylvester
(
A
.
conj
()
.
mT
,
B
.
conj
()
.
mT
,
-
dX
)
# Eq 10
A_bar
=
S
@
X
.
conj
()
.
mT
# Eq 11a
B_bar
=
X
.
conj
()
.
mT
@
S
# Eq 11b
Q_bar
=
-
S
# Eq 11c
return
[
A_bar
,
B_bar
,
Q_bar
]
def
solve_sylvester
(
A
:
TensorLike
,
B
:
TensorLike
,
Q
:
TensorLike
)
->
TensorVariable
:
"""
Solve the Sylvester equation :math:`AX + XB = C` for :math:`X`.
Following scipy notation, this function solves the continuous-time Sylvester equation.
Parameters
----------
A: TensorLike
Square matrix of shape ``M x M``.
B: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Matrix of shape ``M x N``.
Returns
-------
X: TensorVariable
Matrix of shape ``M x N``.
"""
A
=
as_tensor_variable
(
A
)
B
=
as_tensor_variable
(
B
)
Q
=
as_tensor_variable
(
Q
)
A_matrix
=
pt
.
matrix
(
dtype
=
A
.
dtype
,
shape
=
A
.
type
.
shape
[
-
2
:])
B_matrix
=
pt
.
matrix
(
dtype
=
B
.
dtype
,
shape
=
B
.
type
.
shape
[
-
2
:])
Q_matrix
=
pt
.
matrix
(
dtype
=
Q
.
dtype
,
shape
=
Q
.
type
.
shape
[
-
2
:])
R
,
U
=
schur
(
A_matrix
,
output
=
"real"
)
S
,
V
=
schur
(
B_matrix
,
output
=
"real"
)
F
=
U
.
conj
()
.
mT
@
Q_matrix
@
V
Y
=
_trsyl
(
R
,
S
,
F
)
X
=
U
@
Y
@
V
.
conj
()
.
mT
op
=
SolveSylvester
(
inputs
=
[
A_matrix
,
B_matrix
,
Q_matrix
],
outputs
=
[
X
],
lop_overrides
=
_lop_solve_continuous_sylvester
,
)
return
cast
(
TensorVariable
,
Blockwise
(
op
)(
A
,
B
,
Q
))
def
solve_continuous_lyapunov
(
A
:
TensorLike
,
Q
:
TensorLike
)
->
TensorVariable
:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Note that the lyapunov equation is a special case of the Sylvester equation, with :math:`B = A^H`. This function
thus simply calls `solve_sylvester` with the appropriate arguments.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
"""
A
=
as_tensor_variable
(
A
)
Q
=
as_tensor_variable
(
Q
)
return
solve_sylvester
(
A
,
A
.
conj
()
.
mT
,
Q
)
class
SolveBilinearDiscreteLyapunov
(
OpFromGraph
):
"""
Wrapper Op for solving the discrete Lyapunov equation :math:`A X A^H - X = Q` for :math:`X`.
Required so that backends that do not support method='bilinear' in `solve_discrete_lyapunov` can be rewritten
to method='direct'.
"""
def
solve_discrete_lyapunov
(
A
:
TensorLike
,
Q
:
TensorLike
,
method
:
Literal
[
"direct"
,
"bilinear"
]
=
"bilinear"
,
)
->
TensorVariable
:
"""Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
Parameters
----------
A: TensorLike
Square matrix of shape N x N
Q: TensorLike
Square matrix of shape N x N
method: str, one of ``"direct"`` or ``"bilinear"``
Solver method used, . ``"direct"`` solves the problem directly via matrix inversion. This has a pure
PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
used in these cases.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``. Solution to the Lyapunov equation
"""
if
method
not
in
[
"direct"
,
"bilinear"
]:
raise
ValueError
(
f
'Parameter "method" must be one of "direct" or "bilinear", found {method}'
)
A
=
as_tensor_variable
(
A
)
Q
=
as_tensor_variable
(
Q
)
if
method
==
"direct"
:
vec_kron
=
pt
.
vectorize
(
kron
,
signature
=
"(n,n),(n,n)->(m,m)"
)
AxA
=
vec_kron
(
A
,
A
.
conj
())
eye
=
pt
.
eye
(
AxA
.
shape
[
-
1
])
vec_Q
=
join_dims
(
Q
,
start_axis
=-
2
,
n_axes
=
2
)
vec_X
=
solve
(
eye
-
AxA
,
vec_Q
,
b_ndim
=
1
)
return
reshape
(
vec_X
,
A
.
shape
)
elif
method
==
"bilinear"
:
I
=
pt
.
eye
(
A
.
shape
[
-
2
])
B_1
=
A
.
conj
()
.
mT
+
I
B_2
=
A
.
conj
()
.
mT
-
I
B
=
solve
(
B_1
.
mT
,
B_2
.
mT
)
.
mT
AI_inv_Q
=
solve
(
A
+
I
,
Q
)
C
=
2
*
solve
(
B_1
.
mT
,
AI_inv_Q
.
mT
)
.
mT
X
=
solve_continuous_lyapunov
(
B
.
conj
()
.
mT
,
-
C
)
op
=
SolveBilinearDiscreteLyapunov
(
inputs
=
[
A
,
Q
],
outputs
=
[
X
])
return
cast
(
TensorVariable
,
op
(
A
,
Q
))
else
:
raise
ValueError
(
f
"Unknown method {method}"
)
class
SolveDiscreteARE
(
Op
):
__props__
=
(
"enforce_Q_symmetric"
,)
gufunc_signature
=
"(m,m),(m,n),(m,m),(n,n)->(m,m)"
def
__init__
(
self
,
enforce_Q_symmetric
:
bool
=
False
):
self
.
enforce_Q_symmetric
=
enforce_Q_symmetric
def
make_node
(
self
,
A
,
B
,
Q
,
R
):
A
=
as_tensor_variable
(
A
)
B
=
as_tensor_variable
(
B
)
Q
=
as_tensor_variable
(
Q
)
R
=
as_tensor_variable
(
R
)
out_dtype
=
pytensor
.
scalar
.
upcast
(
A
.
dtype
,
B
.
dtype
,
Q
.
dtype
,
R
.
dtype
)
X
=
pytensor
.
tensor
.
matrix
(
dtype
=
out_dtype
)
return
pytensor
.
graph
.
basic
.
Apply
(
self
,
[
A
,
B
,
Q
,
R
],
[
X
])
def
perform
(
self
,
node
,
inputs
,
output_storage
):
A
,
B
,
Q
,
R
=
inputs
X
=
output_storage
[
0
]
if
self
.
enforce_Q_symmetric
:
Q
=
0.5
*
(
Q
+
Q
.
T
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
X
[
0
]
=
scipy_linalg
.
solve_discrete_are
(
A
,
B
,
Q
,
R
)
.
astype
(
out_dtype
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
def
grad
(
self
,
inputs
,
output_grads
):
# Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
A
,
B
,
Q
,
R
=
inputs
(
dX
,)
=
output_grads
X
=
self
(
A
,
B
,
Q
,
R
)
K_inner
=
R
+
matrix_dot
(
B
.
T
,
X
,
B
)
# K_inner is guaranteed to be symmetric, because X and R are symmetric
K_inner_inv_BT
=
solve
(
K_inner
,
B
.
T
,
assume_a
=
"sym"
)
K
=
matrix_dot
(
K_inner_inv_BT
,
X
,
A
)
A_tilde
=
A
-
B
.
dot
(
K
)
dX_symm
=
0.5
*
(
dX
+
dX
.
T
)
S
=
solve_discrete_lyapunov
(
A_tilde
,
dX_symm
)
A_bar
=
2
*
matrix_dot
(
X
,
A_tilde
,
S
)
B_bar
=
-
2
*
matrix_dot
(
X
,
A_tilde
,
S
,
K
.
T
)
Q_bar
=
S
R_bar
=
matrix_dot
(
K
,
S
,
K
.
T
)
return
[
A_bar
,
B_bar
,
Q_bar
,
R_bar
]
def
solve_discrete_are
(
A
:
TensorLike
,
B
:
TensorLike
,
Q
:
TensorLike
,
R
:
TensorLike
,
enforce_Q_symmetric
:
bool
=
False
,
)
->
TensorVariable
:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Guassian (LQG) control problems, and as the
steady-state covariance of the Kalman Filter.
Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
solution. This stable solution, if it exists, will be returned by this function.
Parameters
----------
A: TensorLike
Square matrix of shape M x M
B: TensorLike
Square matrix of shape M x M
Q: TensorLike
Symmetric square matrix of shape M x M
R: TensorLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Returns
-------
X: TensorVariable
Square matrix of shape M x M, representing the solution to the DARE
"""
return
cast
(
TensorVariable
,
Blockwise
(
SolveDiscreteARE
(
enforce_Q_symmetric
))(
A
,
B
,
Q
,
R
)
)
def
_largest_common_dtype
(
tensors
:
Sequence
[
TensorVariable
])
->
np
.
dtype
:
def
_largest_common_dtype
(
tensors
:
Sequence
[
TensorVariable
])
->
np
.
dtype
:
return
reduce
(
lambda
l
,
r
:
np
.
promote_types
(
l
,
r
),
[
x
.
dtype
for
x
in
tensors
])
return
reduce
(
lambda
l
,
r
:
np
.
promote_types
(
l
,
r
),
[
x
.
dtype
for
x
in
tensors
])
...
@@ -2311,6 +1963,28 @@ def schur(
...
@@ -2311,6 +1963,28 @@ def schur(
return
Blockwise
(
Schur
(
output
=
output
,
sort
=
sort
))(
A
)
# type: ignore[return-value]
return
Blockwise
(
Schur
(
output
=
output
,
sort
=
sort
))(
A
)
# type: ignore[return-value]
# Names that now live in pytensor.tensor._linalg.solve.linear_control but are
# still reachable here, with a deprecation warning, for backward compatibility.
_deprecated_names = {
    "solve_continuous_lyapunov",
    "solve_discrete_are",
    "solve_discrete_lyapunov",
}


def __getattr__(name):
    """Module-level fallback providing deprecated access to relocated solvers."""
    if name not in _deprecated_names:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    warnings.warn(
        f"{name} has been moved from tensor/slinalg.py as part of a reorganization "
        "of linear algebra routines in Pytensor. Imports from slinalg.py will fail in Pytensor 3.0.\n"
        f"Please use the stable user-facing linalg API: from pytensor.tensor.linalg import {name}",
        DeprecationWarning,
        stacklevel=2,
    )

    # Imported lazily to avoid a circular import at module load time.
    from pytensor.tensor._linalg.solve import linear_control

    return getattr(linear_control, name)
__all__
=
[
__all__
=
[
"block_diag"
,
"block_diag"
,
"cho_solve"
,
"cho_solve"
,
...
@@ -2323,9 +1997,5 @@ __all__ = [
...
@@ -2323,9 +1997,5 @@ __all__ = [
"qr"
,
"qr"
,
"schur"
,
"schur"
,
"solve"
,
"solve"
,
"solve_continuous_lyapunov"
,
"solve_discrete_are"
,
"solve_discrete_lyapunov"
,
"solve_sylvester"
,
"solve_triangular"
,
"solve_triangular"
,
]
]
tests/link/jax/test_slinalg.py
浏览文件 @
3acb9b4f
...
@@ -10,6 +10,7 @@ from pytensor.configdefaults import config
...
@@ -10,6 +10,7 @@ from pytensor.configdefaults import config
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
pytensor.tensor._linalg.solve
import
linear_control
from
pytensor.tensor.math
import
clip
,
cosh
from
pytensor.tensor.math
import
clip
,
cosh
from
pytensor.tensor.type
import
matrix
,
vector
from
pytensor.tensor.type
import
matrix
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -275,7 +276,7 @@ def test_jax_solve_discrete_lyapunov(
...
@@ -275,7 +276,7 @@ def test_jax_solve_discrete_lyapunov(
):
):
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
out
=
pt_slinalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
out
=
linear_control
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
compare_jax_and_py
(
compare_jax_and_py
(
...
@@ -404,6 +405,6 @@ def test_jax_solve_sylvester():
...
@@ -404,6 +405,6 @@ def test_jax_solve_sylvester():
B_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
B_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
C_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
C_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
out
=
pt_slinalg
.
solve_sylvester
(
A
,
B
,
C
)
out
=
linear_control
.
solve_sylvester
(
A
,
B
,
C
)
compare_jax_and_py
([
A
,
B
,
C
],
[
out
],
[
A_val
,
B_val
,
C_val
])
compare_jax_and_py
([
A
,
B
,
C
],
[
out
],
[
A_val
,
B_val
,
C_val
])
tests/link/numba/linalg/solve/test_linear_control.py
浏览文件 @
3acb9b4f
...
@@ -3,6 +3,7 @@ import pytest
...
@@ -3,6 +3,7 @@ import pytest
from
pytensor
import
config
from
pytensor
import
config
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.tensor._linalg.solve
import
linear_control
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
...
@@ -17,7 +18,7 @@ def test_solve_sylvester():
...
@@ -17,7 +18,7 @@ def test_solve_sylvester():
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
B
=
pt
.
matrix
(
"B"
)
B
=
pt
.
matrix
(
"B"
)
C
=
pt
.
matrix
(
"C"
)
C
=
pt
.
matrix
(
"C"
)
X
=
pt
.
linalg
.
solve_sylvester
(
A
,
B
,
C
)
X
=
linear_control
.
solve_sylvester
(
A
,
B
,
C
)
rng
=
np
.
random
.
default_rng
()
rng
=
np
.
random
.
default_rng
()
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
...
@@ -30,7 +31,7 @@ def test_solve_sylvester():
...
@@ -30,7 +31,7 @@ def test_solve_sylvester():
def
test_solve_continuous_lyapunov
():
def
test_solve_continuous_lyapunov
():
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
Q
=
pt
.
matrix
(
"Q"
)
Q
=
pt
.
matrix
(
"Q"
)
X
=
pt
.
linalg
.
solve_continuous_lyapunov
(
A
,
Q
)
X
=
linear_control
.
solve_continuous_lyapunov
(
A
,
Q
)
rng
=
np
.
random
.
default_rng
()
rng
=
np
.
random
.
default_rng
()
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
...
@@ -44,7 +45,7 @@ def test_solve_continuous_lyapunov():
...
@@ -44,7 +45,7 @@ def test_solve_continuous_lyapunov():
def
test_solve_discrete_lyapunov
(
method
):
def
test_solve_discrete_lyapunov
(
method
):
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
Q
=
pt
.
matrix
(
"Q"
)
Q
=
pt
.
matrix
(
"Q"
)
X
=
pt
.
linalg
.
solve_discrete_lyapunov
(
A
,
Q
,
method
=
method
)
X
=
linear_control
.
solve_discrete_lyapunov
(
A
,
Q
,
method
=
method
)
rng
=
np
.
random
.
default_rng
()
rng
=
np
.
random
.
default_rng
()
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
...
...
tests/tensor/test_slinalg.py
浏览文件 @
3acb9b4f
...
@@ -15,6 +15,12 @@ from pytensor.configdefaults import config
...
@@ -15,6 +15,12 @@ from pytensor.configdefaults import config
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.basic
import
equal_computations
from
pytensor.link.numba
import
NumbaLinker
from
pytensor.link.numba
import
NumbaLinker
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor._linalg.solve.linear_control
import
(
solve_continuous_lyapunov
,
solve_discrete_are
,
solve_discrete_lyapunov
,
solve_sylvester
,
)
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
Cholesky
,
Cholesky
,
CholeskySolve
,
CholeskySolve
,
...
@@ -33,10 +39,6 @@ from pytensor.tensor.slinalg import (
...
@@ -33,10 +39,6 @@ from pytensor.tensor.slinalg import (
qr
,
qr
,
schur
,
schur
,
solve
,
solve
,
solve_continuous_lyapunov
,
solve_discrete_are
,
solve_discrete_lyapunov
,
solve_sylvester
,
solve_triangular
,
solve_triangular
,
)
)
from
pytensor.tensor.type
import
dmatrix
,
matrix
,
tensor
,
vector
from
pytensor.tensor.type
import
dmatrix
,
matrix
,
tensor
,
vector
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论