提交 9834e96c authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Restore jax bilinear_to_direct rewrite when `solve_sylvester` is not available

上级 ca26e81a
...@@ -6,6 +6,10 @@ from pytensor.graph import Constant, graph_inputs ...@@ -6,6 +6,10 @@ from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1 from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.linear_control import (
SolveBilinearDiscreteLyapunov,
solve_discrete_lyapunov,
)
from pytensor.tensor._linalg.solve.tridiagonal import ( from pytensor.tensor._linalg.solve.tridiagonal import (
tridiagonal_lu_factor, tridiagonal_lu_factor,
tridiagonal_lu_solve, tridiagonal_lu_solve,
...@@ -269,3 +273,36 @@ scan_seqopt1.register( ...@@ -269,3 +273,36 @@ scan_seqopt1.register(
use_db_name_as_tag=False, use_db_name_as_tag=False,
position=2, position=2,
) )
def _load_solve_sylvester():
# Thin import wrapper to help with testing
from jax.scipy.linalg import solve_sylvester
return solve_sylvester
@node_rewriter([SolveBilinearDiscreteLyapunov])
def jax_bilinear_lyapunov_to_direct(fgraph, node):
"""
Replace SolveBilinearDiscreteLyapunov with a direct computation that is supported by JAX < 0.8
"""
try:
_load_solve_sylvester()
return None
except ImportError:
# solve_sylvester is only available in jax > 0.8, which is not available on conda-forge.
# If it's not available, we can drop back to method="direct"
A, B = node.inputs
result = solve_discrete_lyapunov(A, B, method="direct")
return [result]
optdb.register(
"jax_bilinear_lyapunov_to_direct",
dfs_rewriter(jax_bilinear_lyapunov_to_direct),
"jax",
position=0.9, # Run before canonicalization
)
...@@ -291,6 +291,37 @@ def test_jax_solve_discrete_lyapunov( ...@@ -291,6 +291,37 @@ def test_jax_solve_discrete_lyapunov(
) )
def test_bilinear_to_direct_rewrite(monkeypatch):
mock_called = []
def mock_load_solve_sylvester():
mock_called.append(True)
raise ImportError("Simulated ImportError for testing.")
monkeypatch.setattr(
"pytensor.tensor._linalg.solve.rewriting._load_solve_sylvester",
mock_load_solve_sylvester,
)
A = pt.tensor(name="A", shape=(3, 3))
B = pt.tensor(name="B", shape=(3, 3))
out = linear_control.solve_discrete_lyapunov(A, B, method="bilinear")
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py(
[A, B],
[out],
[
np.random.normal(size=(3, 3)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX),
],
jax_mode="JAX",
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
)
assert mock_called
@pytest.mark.parametrize( @pytest.mark.parametrize(
"permute_l, p_indices", "permute_l, p_indices",
[(True, False), (False, True), (False, False)], [(True, False), (False, True), (False, False)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论