提交 d9889cca authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

JAX dispatch for linear control Ops

上级 f7d1c644
...@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import ( ...@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import (
PivotToPermutations, PivotToPermutations,
Schur, Schur,
Solve, Solve,
SolveSylvester,
SolveTriangular, SolveTriangular,
) )
...@@ -200,3 +201,11 @@ def jax_funcify_Schur(op, **kwargs): ...@@ -200,3 +201,11 @@ def jax_funcify_Schur(op, **kwargs):
return T, Z return T, Z
return schur return schur
@jax_funcify.register(SolveSylvester)
def jax_funcify_SolveSylsterer(op, **kwargs):
def solve_sylvester(a, b, c):
return jax.scipy.linalg.solve_sylvester(a, b, c)
return solve_sylvester
...@@ -6,11 +6,9 @@ import numpy as np ...@@ -6,11 +6,9 @@ import numpy as np
from pytensor import Variable from pytensor import Variable
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
copy_stack_trace, copy_stack_trace,
dfs_rewriter,
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.rewriting.unify import OpPattern
...@@ -55,12 +53,10 @@ from pytensor.tensor.slinalg import ( ...@@ -55,12 +53,10 @@ from pytensor.tensor.slinalg import (
LUFactor, LUFactor,
Solve, Solve,
SolveBase, SolveBase,
SolveBilinearDiscreteLyapunov,
SolveTriangular, SolveTriangular,
block_diag, block_diag,
cholesky, cholesky,
solve, solve,
solve_discrete_lyapunov,
solve_triangular, solve_triangular,
) )
...@@ -1045,25 +1041,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): ...@@ -1045,25 +1041,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)] return [eye_input * (non_eye_input**0.5)]
@node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
"""
A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct")
return [result]
optdb.register(
"jax_bilinaer_lyapunov_to_direct",
dfs_rewriter(jax_bilinaer_lyapunov_to_direct),
"jax",
position=0.9, # Run before canonicalization
)
@register_specialize @register_specialize
@node_rewriter([det]) @node_rewriter([det])
def slogdet_specialization(fgraph, node): def slogdet_specialization(fgraph, node):
......
...@@ -392,3 +392,18 @@ def test_jax_schur(output): ...@@ -392,3 +392,18 @@ def test_jax_schur(output):
T, Z = pt_slinalg.schur(A, output=output) T, Z = pt_slinalg.schur(A, output=output)
compare_jax_and_py([A], [T, Z], [A_val]) compare_jax_and_py([A], [T, Z], [A_val])
def test_jax_solve_sylvester():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor(name="A", shape=(3, 3))
B = pt.tensor(name="B", shape=(3, 3))
C = pt.tensor(name="C", shape=(3, 3))
A_val = rng.normal(size=(3, 3)).astype(config.floatX)
B_val = rng.normal(size=(3, 3)).astype(config.floatX)
C_val = rng.normal(size=(3, 3)).astype(config.floatX)
out = pt_slinalg.solve_sylvester(A, B, C)
compare_jax_and_py([A, B, C], [out], [A_val, B_val, C_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论