Unverified 提交 f72d7e58 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add JAX dispatch for `CholeskySolve` `Op` (#1491)

* Add jax dispatch for CholeskySolve * Better typehints on user-facing `cho_solve` * Rename test
上级 d3bbc20a
...@@ -7,6 +7,7 @@ from pytensor.tensor.slinalg import ( ...@@ -7,6 +7,7 @@ from pytensor.tensor.slinalg import (
LU, LU,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve,
Eigvalsh, Eigvalsh,
LUFactor, LUFactor,
PivotToPermutations, PivotToPermutations,
...@@ -153,3 +154,17 @@ def jax_funcify_LUFactor(op, **kwargs): ...@@ -153,3 +154,17 @@ def jax_funcify_LUFactor(op, **kwargs):
) )
return lu_factor return lu_factor
@jax_funcify.register(CholeskySolve)
def jax_funcify_ChoSolve(op, **kwargs):
lower = op.lower
check_finite = op.check_finite
overwrite_b = op.overwrite_b
def cho_solve(c, b):
return jax.scipy.linalg.cho_solve(
(c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b
)
return cho_solve
...@@ -376,14 +376,20 @@ class CholeskySolve(SolveBase): ...@@ -376,14 +376,20 @@ class CholeskySolve(SolveBase):
return self return self
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): def cho_solve(
c_and_lower: tuple[TensorLike, bool],
b: TensorLike,
*,
check_finite: bool = True,
b_ndim: int | None = None,
):
"""Solve the linear equations A x = b, given the Cholesky factorization of A. """Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters Parameters
---------- ----------
(c, lower) : tuple, (array, bool) c_and_lower : tuple of (TensorLike, bool)
Cholesky factorization of a, as given by cho_factor Cholesky factorization of a, as given by cho_factor
b : array b : TensorLike
Right-hand side Right-hand side
check_finite : bool, optional check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
......
...@@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape): ...@@ -333,3 +333,19 @@ def test_jax_lu_solve(b_shape):
out = pt_slinalg.lu_solve(lu_and_pivots, b) out = pt_slinalg.lu_solve(lu_and_pivots, b)
compare_jax_and_py([A, b], [out], [A_val, b_val]) compare_jax_and_py([A, b], [out], [A_val, b_val])
@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])
def test_jax_cho_solve(b_shape, lower):
rng = np.random.default_rng(utt.fetch_seed())
L_val = rng.normal(size=(5, 5)).astype(config.floatX)
A_val = (L_val @ L_val.T).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
A = pt.tensor(name="A", shape=(5, 5))
b = pt.tensor(name="b", shape=b_shape)
c = pt_slinalg.cholesky(A, lower=lower)
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
compare_jax_and_py([A, b], [out], [A_val, b_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论