Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d9889cca
提交
d9889cca
authored
1月 13, 2026
作者:
jessegrabowski
提交者:
Jesse Grabowski
1月 18, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
JAX dispatch for linear control Ops
上级
f7d1c644
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
24 行增加
和
23 行删除
+24
-23
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+9
-0
linalg.py
pytensor/tensor/rewriting/linalg.py
+0
-23
test_slinalg.py
tests/link/jax/test_slinalg.py
+15
-0
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
d9889cca
...
@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import (
...
@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import (
PivotToPermutations
,
PivotToPermutations
,
Schur
,
Schur
,
Solve
,
Solve
,
SolveSylvester
,
SolveTriangular
,
SolveTriangular
,
)
)
...
@@ -200,3 +201,11 @@ def jax_funcify_Schur(op, **kwargs):
...
@@ -200,3 +201,11 @@ def jax_funcify_Schur(op, **kwargs):
return
T
,
Z
return
T
,
Z
return
schur
return
schur
@jax_funcify.register
(
SolveSylvester
)
def
jax_funcify_SolveSylsterer
(
op
,
**
kwargs
):
def
solve_sylvester
(
a
,
b
,
c
):
return
jax
.
scipy
.
linalg
.
solve_sylvester
(
a
,
b
,
c
)
return
solve_sylvester
pytensor/tensor/rewriting/linalg.py
浏览文件 @
d9889cca
...
@@ -6,11 +6,9 @@ import numpy as np
...
@@ -6,11 +6,9 @@ import numpy as np
from
pytensor
import
Variable
from
pytensor
import
Variable
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Apply
,
FunctionGraph
from
pytensor.graph
import
Apply
,
FunctionGraph
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
copy_stack_trace
,
copy_stack_trace
,
dfs_rewriter
,
node_rewriter
,
node_rewriter
,
)
)
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.rewriting.unify
import
OpPattern
...
@@ -55,12 +53,10 @@ from pytensor.tensor.slinalg import (
...
@@ -55,12 +53,10 @@ from pytensor.tensor.slinalg import (
LUFactor
,
LUFactor
,
Solve
,
Solve
,
SolveBase
,
SolveBase
,
SolveBilinearDiscreteLyapunov
,
SolveTriangular
,
SolveTriangular
,
block_diag
,
block_diag
,
cholesky
,
cholesky
,
solve
,
solve
,
solve_discrete_lyapunov
,
solve_triangular
,
solve_triangular
,
)
)
...
@@ -1045,25 +1041,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
...
@@ -1045,25 +1041,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return
[
eye_input
*
(
non_eye_input
**
0.5
)]
return
[
eye_input
*
(
non_eye_input
**
0.5
)]
@node_rewriter
([
SolveBilinearDiscreteLyapunov
])
def
jax_bilinaer_lyapunov_to_direct
(
fgraph
:
FunctionGraph
,
node
:
Apply
):
"""
Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
"""
A
,
B
=
(
cast
(
TensorVariable
,
x
)
for
x
in
node
.
inputs
)
result
=
solve_discrete_lyapunov
(
A
,
B
,
method
=
"direct"
)
return
[
result
]
optdb
.
register
(
"jax_bilinaer_lyapunov_to_direct"
,
dfs_rewriter
(
jax_bilinaer_lyapunov_to_direct
),
"jax"
,
position
=
0.9
,
# Run before canonicalization
)
@register_specialize
@register_specialize
@node_rewriter
([
det
])
@node_rewriter
([
det
])
def
slogdet_specialization
(
fgraph
,
node
):
def
slogdet_specialization
(
fgraph
,
node
):
...
...
tests/link/jax/test_slinalg.py
浏览文件 @
d9889cca
...
@@ -392,3 +392,18 @@ def test_jax_schur(output):
...
@@ -392,3 +392,18 @@ def test_jax_schur(output):
T
,
Z
=
pt_slinalg
.
schur
(
A
,
output
=
output
)
T
,
Z
=
pt_slinalg
.
schur
(
A
,
output
=
output
)
compare_jax_and_py
([
A
],
[
T
,
Z
],
[
A_val
])
compare_jax_and_py
([
A
],
[
T
,
Z
],
[
A_val
])
def
test_jax_solve_sylvester
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
(
3
,
3
))
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
(
3
,
3
))
C
=
pt
.
tensor
(
name
=
"C"
,
shape
=
(
3
,
3
))
A_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
B_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
C_val
=
rng
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
out
=
pt_slinalg
.
solve_sylvester
(
A
,
B
,
C
)
compare_jax_and_py
([
A
,
B
,
C
],
[
out
],
[
A_val
,
B_val
,
C_val
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论