提交 679b2f71 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

JAX dispatches for LU Ops

上级 1aa9a396
...@@ -4,9 +4,12 @@ import jax ...@@ -4,9 +4,12 @@ import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
Eigvalsh, Eigvalsh,
LUFactor,
PivotToPermutations,
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
...@@ -93,3 +96,46 @@ def jax_funcify_BlockDiagonalMatrix(op, **kwargs): ...@@ -93,3 +96,46 @@ def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
return jax.scipy.linalg.block_diag(*inputs) return jax.scipy.linalg.block_diag(*inputs)
return block_diag return block_diag
@jax_funcify.register(PivotToPermutations)
def jax_funcify_PivotToPermutation(op, **kwargs):
    """Return a JAX implementation of the PivotToPermutations Op.

    Converts LAPACK-style pivot indices (as returned by ``lu_factor``) into a
    permutation vector, or its inverse, depending on ``op.inverse``.
    """
    inverse = op.inverse

    def pivot_to_permutations(pivots):
        # Use the core (last) axis length as the permutation size so batched
        # pivot arrays of shape (..., n) are handled correctly; for an
        # unbatched (n,) vector this is identical to pivots.shape[0].
        p_inv = jax.lax.linalg.lu_pivots_to_permutation(pivots, pivots.shape[-1])
        if inverse:
            return p_inv
        # argsort of a permutation yields its inverse permutation.
        return jax.numpy.argsort(p_inv)

    return pivot_to_permutations
@jax_funcify.register(LU)
def jax_funcify_LU(op, **kwargs):
    """Return a JAX implementation of the LU decomposition Op.

    Raises
    ------
    ValueError
        If the Op was constructed with ``p_indices=True``, which JAX's
        ``lu`` wrapper does not support.
    """
    # Guard clause: reject the unsupported mode at dispatch time, before
    # building the thunk.
    if op.p_indices:
        raise ValueError("JAX does not support the p_indices argument")

    permute_l = op.permute_l
    check_finite = op.check_finite

    def lu(*inputs):
        # Returns (P, L, U) or (PL, U) depending on permute_l.
        return jax.scipy.linalg.lu(
            *inputs, permute_l=permute_l, check_finite=check_finite
        )

    return lu
@jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs):
    """Return a JAX implementation of the LUFactor Op (LU factorization with pivots)."""
    # Capture the Op's options once, at dispatch time.
    factor_kwargs = {
        "check_finite": op.check_finite,
        "overwrite_a": op.overwrite_a,
    }

    def lu_factor(a):
        return jax.scipy.linalg.lu_factor(a, **factor_kwargs)

    return lu_factor
...@@ -10,9 +10,8 @@ from numpy.exceptions import ComplexWarning ...@@ -10,9 +10,8 @@ from numpy.exceptions import ComplexWarning
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
...@@ -616,7 +615,7 @@ class PivotToPermutations(Op): ...@@ -616,7 +615,7 @@ class PivotToPermutations(Op):
outputs[0][0] = np.argsort(p_inv) outputs[0][0] = np.argsort(p_inv)
def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable: def pivot_to_permutation(p: TensorLike, inverse=False):
p = pt.as_tensor_variable(p) p = pt.as_tensor_variable(p)
return PivotToPermutations(inverse=inverse)(p) return PivotToPermutations(inverse=inverse)(p)
...@@ -724,29 +723,6 @@ def lu_factor( ...@@ -724,29 +723,6 @@ def lu_factor(
) )
class LUSolve(OpFromGraph):
    """Solve a system of linear equations given the LU decomposition of the matrix."""

    # Options participating in Op equality/hashing.
    __props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")

    def __init__(
        self,
        inputs: list[Variable],
        outputs: list[Variable],
        trans: bool = False,
        b_ndim: int | None = None,
        check_finite: bool = False,
        overwrite_b: bool = False,
        **kwargs,
    ):
        # Store solver options on the instance so __props__ can see them,
        # then delegate graph construction to OpFromGraph.
        self.trans = trans
        self.b_ndim = b_ndim
        self.check_finite = check_finite
        self.overwrite_b = overwrite_b
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
def lu_solve( def lu_solve(
LU_and_pivots: tuple[TensorLike, TensorLike], LU_and_pivots: tuple[TensorLike, TensorLike],
b: TensorLike, b: TensorLike,
......
...@@ -228,3 +228,60 @@ def test_jax_solve_discrete_lyapunov( ...@@ -228,3 +228,60 @@ def test_jax_solve_discrete_lyapunov(
jax_mode="JAX", jax_mode="JAX",
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol), assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
) )
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
    """LU decomposition matches between the JAX and Python backends."""
    rng = np.random.default_rng()

    dtype = f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX
    A = pt.tensor("A", shape=shape, dtype=dtype)
    out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    x = rng.normal(size=shape).astype(config.floatX)
    if complex:
        x = x + 1j * rng.normal(size=shape).astype(config.floatX)

    if not p_indices:
        compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
    else:
        # The JAX dispatch of LU rejects p_indices at funcify time.
        with pytest.raises(
            ValueError, match="JAX does not support the p_indices argument"
        ):
            compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
def test_jax_lu_factor(shape):
    """lu_factor agrees between the JAX and Python backends."""
    rng = np.random.default_rng(utt.fetch_seed())

    A = pt.tensor(name="A", shape=shape)
    out = pt_slinalg.lu_factor(A)

    A_value = rng.normal(size=shape).astype(config.floatX)
    compare_jax_and_py([A], out, [A_value])
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
def test_jax_lu_solve(b_shape):
    """Solving with a precomputed LU factorization matches between backends."""
    rng = np.random.default_rng(utt.fetch_seed())

    A = pt.tensor(name="A", shape=(5, 5))
    b = pt.tensor(name="b", shape=b_shape)
    out = pt_slinalg.lu_solve(pt_slinalg.lu_factor(A), b)

    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
    b_val = rng.normal(size=b_shape).astype(config.floatX)
    compare_jax_and_py([A, b], [out], [A_val, b_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论