提交 1aa9a396 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

New Ops related to LU decomposition

上级 ee884b87
差异被折叠。
......@@ -23,6 +23,10 @@ from pytensor.tensor.slinalg import (
cholesky,
eigvalsh,
expm,
lu,
lu_factor,
lu_solve,
pivot_to_permutation,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
......@@ -584,6 +588,177 @@ class TestCholeskySolve(utt.InferShapeTester):
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
@pytest.mark.parametrize(
"permute_l, p_indices",
[(False, True), (True, False), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_lu_decomposition(
permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int]
):
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
A = tensor("A", shape=shape, dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices)
f = pytensor.function([A], out)
rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=shape).astype(config.floatX)
if complex:
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
out = f(x)
if permute_l:
PL, U = out
elif p_indices:
p, L, U = out
if len(shape) == 2:
P = np.eye(5)[p]
else:
P = np.stack([np.eye(5)[idx] for idx in p])
PL = np.einsum("...nk,...km->...nm", P, L)
else:
P, L, U = out
PL = np.einsum("...nk,...km->...nm", P, L)
x_rebuilt = np.einsum("...nk,...km->...nm", PL, U)
np.testing.assert_allclose(
x,
x_rebuilt,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
)
scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
for a, b in zip(out, scipy_out, strict=True):
np.testing.assert_allclose(a, b)
@pytest.mark.parametrize(
"grad_case", [0, 1, 2], ids=["dU_only", "dL_only", "dU_and_dL"]
)
@pytest.mark.parametrize(
"permute_l, p_indices",
[(True, False), (False, True), (False, False)],
ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_lu_grad(grad_case, permute_l, p_indices, shape):
rng = np.random.default_rng(utt.fetch_seed())
A_value = rng.normal(size=shape).astype(config.floatX)
def f_pt(A):
# lu returns either (P_or_index, L, U) or (PL, U), depending on settings
out = lu(A, permute_l=permute_l, p_indices=p_indices, check_finite=False)
match grad_case:
case 0:
return out[-1].sum()
case 1:
return out[-2].sum()
case 2:
return out[-1].sum() + out[-2].sum()
utt.verify_grad(f_pt, [A_value], rng=rng)
@pytest.mark.parametrize("inverse", [True, False], ids=["inverse", "no_inverse"])
def test_pivot_to_permutation(inverse):
rng = np.random.default_rng(utt.fetch_seed())
A_val = rng.normal(size=(5, 5))
_, pivots = scipy.linalg.lu_factor(A_val)
perm_idx, *_ = scipy.linalg.lu(A_val, p_indices=True)
if not inverse:
perm_idx_pt = pivot_to_permutation(pivots, inverse=False).eval()
np.testing.assert_array_equal(perm_idx_pt, perm_idx)
else:
p_inv_pt = pivot_to_permutation(pivots, inverse=True).eval()
np.testing.assert_array_equal(p_inv_pt, np.argsort(perm_idx))
class TestLUSolve(utt.InferShapeTester):
@staticmethod
def factor_and_solve(A, b, sum=False, **lu_kwargs):
lu_and_pivots = lu_factor(A)
x = lu_solve(lu_and_pivots, b, **lu_kwargs)
if not sum:
return x
return x.sum()
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"])
@pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"])
def test_lu_solve(self, b_shape: tuple[int], trans):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_shape)
A_val = (
rng.normal(size=(5, 5)).astype(config.floatX)
+ np.eye(5, dtype=config.floatX) * 0.5
)
b_val = rng.normal(size=b_shape).astype(config.floatX)
x = self.factor_and_solve(A, b, trans=trans, sum=False)
f = pytensor.function([A, b], x)
x_pt = f(A_val.copy(), b_val.copy())
x_sp = scipy.linalg.lu_solve(
scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans
)
np.testing.assert_allclose(x_pt, x_sp)
def T(x):
if trans:
return x.T
return x
np.testing.assert_allclose(
T(A_val) @ x_pt,
b_val,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
)
np.testing.assert_allclose(x_pt, x_sp)
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"])
@pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"])
def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool):
rng = np.random.default_rng(utt.fetch_seed())
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans)
utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
def test_lu_factor():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
f = pytensor.function([A], lu_factor(A))
LU, pt_p_idx = f(A_val)
sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val)
np.testing.assert_allclose(LU, sp_LU)
np.testing.assert_allclose(pt_p_idx, sp_p_idx)
utt.verify_grad(
lambda A: lu_factor(A)[0].sum(),
[A_val],
rng=rng,
)
def test_cho_solve():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论