提交 d9889cca authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

JAX dispatch for linear control Ops

上级 f7d1c644
......@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import (
PivotToPermutations,
Schur,
Solve,
SolveSylvester,
SolveTriangular,
)
......@@ -200,3 +201,11 @@ def jax_funcify_Schur(op, **kwargs):
return T, Z
return schur
@jax_funcify.register(SolveSylvester)
def jax_funcify_SolveSylsterer(op, **kwargs):
def solve_sylvester(a, b, c):
return jax.scipy.linalg.solve_sylvester(a, b, c)
return solve_sylvester
......@@ -6,11 +6,9 @@ import numpy as np
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
dfs_rewriter,
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
......@@ -55,12 +53,10 @@ from pytensor.tensor.slinalg import (
LUFactor,
Solve,
SolveBase,
SolveBilinearDiscreteLyapunov,
SolveTriangular,
block_diag,
cholesky,
solve,
solve_discrete_lyapunov,
solve_triangular,
)
......@@ -1045,25 +1041,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)]
@node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX
"""
A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct")
return [result]
optdb.register(
"jax_bilinaer_lyapunov_to_direct",
dfs_rewriter(jax_bilinaer_lyapunov_to_direct),
"jax",
position=0.9, # Run before canonicalization
)
@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
......
......@@ -392,3 +392,18 @@ def test_jax_schur(output):
T, Z = pt_slinalg.schur(A, output=output)
compare_jax_and_py([A], [T, Z], [A_val])
def test_jax_solve_sylvester():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor(name="A", shape=(3, 3))
B = pt.tensor(name="B", shape=(3, 3))
C = pt.tensor(name="C", shape=(3, 3))
A_val = rng.normal(size=(3, 3)).astype(config.floatX)
B_val = rng.normal(size=(3, 3)).astype(config.floatX)
C_val = rng.normal(size=(3, 3)).astype(config.floatX)
out = pt_slinalg.solve_sylvester(A, B, C)
compare_jax_and_py([A, B, C], [out], [A_val, B_val, C_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论