提交 1aa9a396 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

New Ops related to LU decomposition

上级 ee884b87
......@@ -10,7 +10,9 @@ from numpy.exceptions import ComplexWarning
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
......@@ -225,6 +227,7 @@ class SolveBase(Op):
):
self.lower = lower
self.check_finite = check_finite
assert b_ndim in (1, 2)
self.b_ndim = b_ndim
if b_ndim == 1:
......@@ -302,10 +305,14 @@ class SolveBase(Op):
solve_op = type(self)(**props_dict)
b_bar = solve_op(A.T, c_bar)
b_bar = solve_op(A.mT, c_bar)
# force outer product if vector second input
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if props_dict.get("unit_diagonal", False):
n = A_bar.shape[-1]
A_bar = A_bar[pt.arange(n), pt.arange(n)].set(pt.zeros(n))
return [A_bar, b_bar]
......@@ -394,6 +401,411 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
)(A, b)
class LU(Op):
"""Decompose a matrix into lower and upper triangular matrices."""
__props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
def __init__(
self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
):
if permute_l and p_indices:
raise ValueError("Only one of permute_l and p_indices can be True")
self.permute_l = permute_l
self.check_finite = check_finite
self.p_indices = p_indices
self.overwrite_a = overwrite_a
if self.permute_l:
# permute_l overrides p_indices in the scipy function. We can copy that behavior
self.gufunc_signature = "(m,m)->(m,m),(m,m)"
elif self.p_indices:
self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
else:
self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
if self.overwrite_a:
self.destroy_map = {0: [0]} if self.permute_l else {1: [0]}
def infer_shape(self, fgraph, node, shapes):
n = shapes[0][0]
if self.permute_l:
return [(n, n), (n, n)]
elif self.p_indices:
return [(n,), (n, n), (n, n)]
else:
return [(n, n), (n, n), (n, n)]
def make_node(self, x):
x = as_tensor_variable(x)
if x.type.ndim != 2:
raise TypeError(
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
L = tensor(shape=x.type.shape, dtype=x.type.dtype)
U = tensor(shape=x.type.shape, dtype=x.type.dtype)
if self.permute_l:
# In this case, L is actually P @ L
return Apply(self, inputs=[x], outputs=[L, U])
if self.p_indices:
p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
return Apply(self, inputs=[x], outputs=[p_indices, L, U])
P = tensor(shape=x.type.shape, dtype=p_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U])
def perform(self, node, inputs, outputs):
[A] = inputs
out = scipy_linalg.lu(
A,
permute_l=self.permute_l,
overwrite_a=self.overwrite_a,
check_finite=self.check_finite,
p_indices=self.p_indices,
)
outputs[0][0] = out[0]
outputs[1][0] = out[1]
if not self.permute_l:
# In all cases except permute_l, there are three returns
outputs[2][0] = out[2]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if 0 in allowed_inplace_inputs:
new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
else:
return self
def L_op(
self,
inputs: Sequence[ptb.Variable],
outputs: Sequence[ptb.Variable],
output_grads: Sequence[ptb.Variable],
) -> list[ptb.Variable]:
r"""
Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
F. R. De Hoog, R.S. Anderssen, M. A. Lukas
"""
[A] = inputs
A = cast(TensorVariable, A)
if self.permute_l:
# P has no gradient contribution (by assumption...), so PL_bar is the same as L_bar
L_bar, U_bar = output_grads
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
P_or_indices, L, U = lu( # type: ignore
A, permute_l=False, check_finite=self.check_finite, p_indices=False
)
else:
# In both other cases, there are 3 outputs. The first output will either be the permutation index itself,
# or indices that can be used to reconstruct the permutation matrix.
P_or_indices, L, U = outputs
_, L_bar, U_bar = output_grads
L_bar = (
L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
)
U_bar = (
U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
)
x1 = ptb.tril(L.T @ L_bar, k=-1)
x2 = ptb.triu(U_bar @ U.T)
LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
# Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation
B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
if not self.p_indices:
A_bar = P_or_indices @ B_bar
else:
A_bar = B_bar[P_or_indices]
return [A_bar]
def lu(
a: TensorLike,
permute_l=False,
check_finite=True,
p_indices=False,
overwrite_a: bool = False,
) -> (
tuple[TensorVariable, TensorVariable, TensorVariable]
| tuple[TensorVariable, TensorVariable]
):
"""
Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
... math::
A = P L U
Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
Parameters
----------
a: TensorLike
Matrix to be factorized
permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular.
check_finite: bool
Whether to check that the input matrix contains only finite numbers.
p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself.
overwrite_a: bool
Ignored by Pytensor. Pytensor will always perform computation inplace if possible.
Returns
-------
P: TensorVariable
Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
L: TensorVariable
Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
U: TensorVariable
Upper triangular matrix
"""
return cast(
tuple[TensorVariable, TensorVariable, TensorVariable]
| tuple[TensorVariable, TensorVariable],
Blockwise(
LU(permute_l=permute_l, p_indices=p_indices, check_finite=check_finite)
)(a),
)
class PivotToPermutations(Op):
__props__ = ("inverse",)
def __init__(self, inverse=True):
self.inverse = inverse
def make_node(self, pivots):
pivots = as_tensor_variable(pivots)
if pivots.ndim != 1:
raise ValueError("PivotToPermutations only works on 1-D inputs")
permutations = pivots.type.clone(dtype="int64")()
return Apply(self, [pivots], [permutations])
def perform(self, node, inputs, outputs):
[pivots] = inputs
p_inv = np.arange(len(pivots), dtype=pivots.dtype)
for i in range(len(pivots)):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
if self.inverse:
outputs[0][0] = p_inv
else:
outputs[0][0] = np.argsort(p_inv)
def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
p = pt.as_tensor_variable(p)
return PivotToPermutations(inverse=inverse)(p)
class LUFactor(Op):
__props__ = ("overwrite_a", "check_finite")
gufunc_signature = "(m,m)->(m,m),(m)"
def __init__(self, *, overwrite_a=False, check_finite=True):
self.overwrite_a = overwrite_a
self.check_finite = check_finite
if self.overwrite_a:
self.destroy_map = {1: [0]}
def make_node(self, A):
A = as_tensor_variable(A)
if A.type.ndim != 2:
raise TypeError(
f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
)
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
pivots = vector(shape=(A.type.shape[0],), dtype="int64")
return Apply(self, [A], [LU, pivots])
def infer_shape(self, fgraph, node, shapes):
n = shapes[0][0]
return [(n, n), (n,)]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if 0 in allowed_inplace_inputs:
new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
else:
return self
def perform(self, node, inputs, outputs):
A = inputs[0]
LU, p = scipy_linalg.lu_factor(
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
)
outputs[0][0] = LU
outputs[1][0] = p
def L_op(self, inputs, outputs, output_gradients):
[A] = inputs
LU_bar, _ = output_gradients
LU, p_indices = outputs
eye = ptb.identity_like(A)
L = cast(TensorVariable, ptb.tril(LU, k=-1) + eye)
U = cast(TensorVariable, ptb.triu(LU))
p_indices = pivot_to_permutation(p_indices, inverse=False)
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
L_bar = ptb.tril(LU_bar, k=-1)
U_bar = ptb.triu(LU_bar)
# From here we're in the same situation as the LU gradient derivation
x1 = ptb.tril(L.T @ L_bar, k=-1)
x2 = ptb.triu(U_bar @ U.T)
LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
A_bar = B_bar[p_indices]
return [A_bar]
def lu_factor(
a: TensorLike,
*,
check_finite: bool = True,
overwrite_a: bool = False,
) -> tuple[TensorVariable, TensorVariable]:
"""
LU factorization with partial pivoting.
Parameters
----------
a: TensorLike
Matrix to be factorized
check_finite: bool
Whether to check that the input matrix contains only finite numbers.
overwrite_a: bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
Returns
-------
LU: TensorVariable
LU decomposition of `a`
pivots: TensorVariable
An array of integers representin the pivot indices
"""
return cast(
tuple[TensorVariable, TensorVariable],
Blockwise(LUFactor(check_finite=check_finite))(a),
)
class LUSolve(OpFromGraph):
"""Solve a system of linear equations given the LU decomposition of the matrix."""
__props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")
def __init__(
self,
inputs: list[Variable],
outputs: list[Variable],
trans: bool = False,
b_ndim: int | None = None,
check_finite: bool = False,
overwrite_b: bool = False,
**kwargs,
):
self.trans = trans
self.b_ndim = b_ndim
self.check_finite = check_finite
self.overwrite_b = overwrite_b
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
def lu_solve(
LU_and_pivots: tuple[TensorLike, TensorLike],
b: TensorLike,
trans: bool = False,
b_ndim: int | None = None,
check_finite: bool = True,
overwrite_b: bool = False,
):
"""
Solve a system of linear equations given the LU decomposition of the matrix.
Parameters
----------
LU_and_pivots: tuple[TensorLike, TensorLike]
LU decomposition of the matrix, as returned by `lu_factor`
b: TensorLike
Right-hand side of the equation
trans: bool
If True, solve A^T x = b, instead of Ax = b. Default is False
b_ndim: int, optional
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True.
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
b_ndim = _default_b_ndim(b, b_ndim)
LU, pivots = LU_and_pivots
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
inv_permutation = pivot_to_permutation(pivots, inverse=True)
x = b[inv_permutation] if not trans else b
x = solve_triangular(
LU,
x,
lower=not trans,
unit_diagonal=not trans,
trans=trans,
b_ndim=b_ndim,
check_finite=check_finite,
)
x = solve_triangular(
LU,
x,
lower=trans,
unit_diagonal=trans,
trans=trans,
b_ndim=b_ndim,
check_finite=check_finite,
)
x = x[pt.argsort(inv_permutation)] if trans else x
return x
class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
......@@ -408,6 +820,9 @@ class SolveTriangular(SolveBase):
def __init__(self, *, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolverTriangulare")
# There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
# transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
super().__init__(**kwargs)
self.unit_diagonal = unit_diagonal
......@@ -1265,4 +1680,7 @@ __all__ = [
"solve_triangular",
"block_diag",
"cho_solve",
"lu",
"lu_factor",
"lu_solve",
]
......@@ -23,6 +23,10 @@ from pytensor.tensor.slinalg import (
cholesky,
eigvalsh,
expm,
lu,
lu_factor,
lu_solve,
pivot_to_permutation,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
......@@ -584,6 +588,177 @@ class TestCholeskySolve(utt.InferShapeTester):
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
@pytest.mark.parametrize(
"permute_l, p_indices",
[(False, True), (True, False), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_lu_decomposition(
permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int]
):
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
A = tensor("A", shape=shape, dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices)
f = pytensor.function([A], out)
rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=shape).astype(config.floatX)
if complex:
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
out = f(x)
if permute_l:
PL, U = out
elif p_indices:
p, L, U = out
if len(shape) == 2:
P = np.eye(5)[p]
else:
P = np.stack([np.eye(5)[idx] for idx in p])
PL = np.einsum("...nk,...km->...nm", P, L)
else:
P, L, U = out
PL = np.einsum("...nk,...km->...nm", P, L)
x_rebuilt = np.einsum("...nk,...km->...nm", PL, U)
np.testing.assert_allclose(
x,
x_rebuilt,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
)
scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
for a, b in zip(out, scipy_out, strict=True):
np.testing.assert_allclose(a, b)
@pytest.mark.parametrize(
"grad_case", [0, 1, 2], ids=["dU_only", "dL_only", "dU_and_dL"]
)
@pytest.mark.parametrize(
"permute_l, p_indices",
[(True, False), (False, True), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_lu_grad(grad_case, permute_l, p_indices, shape):
rng = np.random.default_rng(utt.fetch_seed())
A_value = rng.normal(size=shape).astype(config.floatX)
def f_pt(A):
# lu returns either (P_or_index, L, U) or (PL, U), depending on settings
out = lu(A, permute_l=permute_l, p_indices=p_indices, check_finite=False)
match grad_case:
case 0:
return out[-1].sum()
case 1:
return out[-2].sum()
case 2:
return out[-1].sum() + out[-2].sum()
utt.verify_grad(f_pt, [A_value], rng=rng)
@pytest.mark.parametrize("inverse", [True, False], ids=["inverse", "no_inverse"])
def test_pivot_to_permutation(inverse):
rng = np.random.default_rng(utt.fetch_seed())
A_val = rng.normal(size=(5, 5))
_, pivots = scipy.linalg.lu_factor(A_val)
perm_idx, *_ = scipy.linalg.lu(A_val, p_indices=True)
if not inverse:
perm_idx_pt = pivot_to_permutation(pivots, inverse=False).eval()
np.testing.assert_array_equal(perm_idx_pt, perm_idx)
else:
p_inv_pt = pivot_to_permutation(pivots, inverse=True).eval()
np.testing.assert_array_equal(p_inv_pt, np.argsort(perm_idx))
class TestLUSolve(utt.InferShapeTester):
@staticmethod
def factor_and_solve(A, b, sum=False, **lu_kwargs):
lu_and_pivots = lu_factor(A)
x = lu_solve(lu_and_pivots, b, **lu_kwargs)
if not sum:
return x
return x.sum()
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"])
@pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"])
def test_lu_solve(self, b_shape: tuple[int], trans):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_shape)
A_val = (
rng.normal(size=(5, 5)).astype(config.floatX)
+ np.eye(5, dtype=config.floatX) * 0.5
)
b_val = rng.normal(size=b_shape).astype(config.floatX)
x = self.factor_and_solve(A, b, trans=trans, sum=False)
f = pytensor.function([A, b], x)
x_pt = f(A_val.copy(), b_val.copy())
x_sp = scipy.linalg.lu_solve(
scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans
)
np.testing.assert_allclose(x_pt, x_sp)
def T(x):
if trans:
return x.T
return x
np.testing.assert_allclose(
T(A_val) @ x_pt,
b_val,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
)
np.testing.assert_allclose(x_pt, x_sp)
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"])
@pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"])
def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool):
rng = np.random.default_rng(utt.fetch_seed())
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans)
utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
def test_lu_factor():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
f = pytensor.function([A], lu_factor(A))
LU, pt_p_idx = f(A_val)
sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val)
np.testing.assert_allclose(LU, sp_LU)
np.testing.assert_allclose(pt_p_idx, sp_p_idx)
utt.verify_grad(
lambda A: lu_factor(A)[0].sum(),
[A_val],
rng=rng,
)
def test_cho_solve():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论