提交 5335a680 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix lu_solve with batch inputs

上级 040410f4
import logging import logging
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import reduce from functools import partial, reduce
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
...@@ -589,6 +589,7 @@ def lu( ...@@ -589,6 +589,7 @@ def lu(
class PivotToPermutations(Op): class PivotToPermutations(Op):
gufunc_signature = "(x)->(x)"
__props__ = ("inverse",) __props__ = ("inverse",)
def __init__(self, inverse=True): def __init__(self, inverse=True):
...@@ -723,40 +724,22 @@ def lu_factor( ...@@ -723,40 +724,22 @@ def lu_factor(
) )
def lu_solve( def _lu_solve(
LU_and_pivots: tuple[TensorLike, TensorLike], LU: TensorLike,
pivots: TensorLike,
b: TensorLike, b: TensorLike,
trans: bool = False, trans: bool = False,
b_ndim: int | None = None, b_ndim: int | None = None,
check_finite: bool = True, check_finite: bool = True,
overwrite_b: bool = False,
): ):
"""
Solve a system of linear equations given the LU decomposition of the matrix.
Parameters
----------
LU_and_pivots: tuple[TensorLike, TensorLike]
LU decomposition of the matrix, as returned by `lu_factor`
b: TensorLike
Right-hand side of the equation
trans: bool
If True, solve A^T x = b, instead of Ax = b. Default is False
b_ndim: int, optional
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True.
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
LU, pivots = LU_and_pivots
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b]) LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
inv_permutation = pivot_to_permutation(pivots, inverse=True)
inv_permutation = pivot_to_permutation(pivots, inverse=True)
x = b[inv_permutation] if not trans else b x = b[inv_permutation] if not trans else b
# TODO: Use PermuteRows on b
# x = permute_rows(b, pivots) if not trans else b
x = solve_triangular( x = solve_triangular(
LU, LU,
...@@ -777,11 +760,52 @@ def lu_solve( ...@@ -777,11 +760,52 @@ def lu_solve(
b_ndim=b_ndim, b_ndim=b_ndim,
check_finite=check_finite, check_finite=check_finite,
) )
x = x[pt.argsort(inv_permutation)] if trans else x
# TODO: Use PermuteRows(inverse=True) on x
# if trans:
# x = permute_rows(x, pivots, inverse=True)
x = x[pt.argsort(inv_permutation)] if trans else x
return x return x
def lu_solve(
LU_and_pivots: tuple[TensorLike, TensorLike],
b: TensorLike,
trans: bool = False,
b_ndim: int | None = None,
check_finite: bool = True,
overwrite_b: bool = False,
):
"""
Solve a system of linear equations given the LU decomposition of the matrix.
Parameters
----------
LU_and_pivots: tuple[TensorLike, TensorLike]
LU decomposition of the matrix, as returned by `lu_factor`
b: TensorLike
Right-hand side of the equation
trans: bool
If True, solve A^T x = b, instead of Ax = b. Default is False
b_ndim: int, optional
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True.
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
b_ndim = _default_b_ndim(b, b_ndim)
if b_ndim == 1:
signature = "(m,m),(m),(m)->(m)"
else:
signature = "(m,m),(m),(m,n)->(m,n)"
partialled_func = partial(
_lu_solve, trans=trans, b_ndim=b_ndim, check_finite=check_finite
)
return pt.vectorize(partialled_func, signature=signature)(*LU_and_pivots, b)
class SolveTriangular(SolveBase): class SolveTriangular(SolveBase):
"""Solve a system of linear equations.""" """Solve a system of linear equations."""
......
...@@ -737,6 +737,22 @@ class TestLUSolve(utt.InferShapeTester): ...@@ -737,6 +737,22 @@ class TestLUSolve(utt.InferShapeTester):
test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans) test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans)
utt.verify_grad(test_fn, [A_val, b_val], 3, rng) utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
def test_lu_solve_batch_dims(self):
A = pt.tensor("A", shape=(3, 1, 5, 5))
b = pt.tensor("b", shape=(1, 4, 5))
lu_and_pivots = lu_factor(A)
x = lu_solve(lu_and_pivots, b, b_ndim=1)
assert x.type.shape in {(3, 4, None), (3, 4, 5)}
rng = np.random.default_rng(748)
A_test = rng.random(A.type.shape).astype(A.type.dtype)
b_test = rng.random(b.type.shape).astype(b.type.dtype)
np.testing.assert_allclose(
x.eval({A: A_test, b: b_test}),
solve(A, b, b_ndim=1).eval({A: A_test, b: b_test}),
rtol=1e-9 if config.floatX == "float64" else 1e-5,
)
def test_lu_factor(): def test_lu_factor():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论