Unverified 提交 bf628c97 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Allow `transposed` argument in `linalg.solve` (#1231)

* Add transposed argument to `solve` and `solve_triangular` * Expand test coverage for `Solve` and `SolveTriangular`
上级 757a10cd
......@@ -53,7 +53,6 @@ def jax_funcify_Solve(op, **kwargs):
@jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
......@@ -62,7 +61,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
A,
b,
lower=lower,
trans=trans,
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)
......
......@@ -180,7 +180,6 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = bool(op.trans)
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
......@@ -208,7 +207,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
res = _solve_triangular(
a,
b,
trans=trans,
trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
......
......@@ -296,13 +296,12 @@ class SolveBase(Op):
# We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0]
trans_solve_op = type(self)(
**{
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
for k in self.__props__
}
)
b_bar = trans_solve_op(A.T, c_bar)
props_dict = self._props_dict()
props_dict["lower"] = not self.lower
solve_op = type(self)(**props_dict)
b_bar = solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
......@@ -385,7 +384,6 @@ class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
__props__ = (
"trans",
"unit_diagonal",
"lower",
"check_finite",
......@@ -393,11 +391,10 @@ class SolveTriangular(SolveBase):
"overwrite_b",
)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
def __init__(self, *, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolverTriangulare")
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal
def perform(self, node, inputs, outputs):
......@@ -406,7 +403,7 @@ class SolveTriangular(SolveBase):
A,
b,
lower=self.lower,
trans=self.trans,
trans=0,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
......@@ -445,9 +442,9 @@ def solve_triangular(
Parameters
----------
a
a: TensorVariable
Square input data
b
b: TensorVariable
Input data for the right hand side.
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
......@@ -468,10 +465,17 @@ def solve_triangular(
This will influence how batched dimensions are interpreted.
"""
b_ndim = _default_b_ndim(b, b_ndim)
if trans in [1, "T", True]:
a = a.mT
lower = not lower
if trans in [2, "C"]:
a = a.conj().mT
lower = not lower
ret = Blockwise(
SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim=b_ndim,
......@@ -534,6 +538,7 @@ def solve(
*,
assume_a="gen",
lower=False,
transposed=False,
check_finite=True,
b_ndim: int | None = None,
):
......@@ -564,8 +569,10 @@ def solve(
b : (..., N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default
If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
......@@ -577,6 +584,11 @@ def solve(
This will influence how batched dimensions are interpreted.
"""
b_ndim = _default_b_ndim(b, b_ndim)
if transposed:
a = a.mT
lower = not lower
return Blockwise(
Solve(
lower=lower,
......
......@@ -5,6 +5,7 @@ import numpy as np
import pytest
import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg
......@@ -103,28 +104,41 @@ def test_jax_basic():
)
@pytest.mark.parametrize("check_finite", [False, True])
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite):
x = matrix("x")
b = vector("b")
def test_jax_solve():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("B", shape=(5, 5))
out = pt_slinalg.solve(A, b, lower=False, transposed=False)
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
compare_jax_and_py(
[x, b],
[A, b],
[out],
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
[A_val, b_val],
)
def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("B", shape=(5, 5))
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.solve_triangular(
A,
b,
trans=0,
lower=True,
unit_diagonal=False,
)
compare_jax_and_py([A, b], [out], [A_val, b_val])
def test_jax_block_diag():
......
......@@ -5,7 +5,6 @@ from typing import Literal
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy import linalg as scipy_linalg
import pytensor
import pytensor.tensor as pt
......@@ -26,9 +25,9 @@ def transpose_func(x, trans):
if trans == 0:
return x
if trans == 1:
return x.conj().T
if trans == 2:
return x.T
if trans == 2:
return x.conj().T
@pytest.mark.parametrize(
......@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
def A_func(x):
x = x @ x.conj().T
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
if unit_diag:
x_tri[np.diag_indices_from(x_tri)] = 1.0
x_tri = pt.fill_diagonal(x_tri, 1.0)
return x_tri.astype(dtype)
return x_tri
solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)
X = solve_op(A, b)
X = solve_op(A_func(A), b)
f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5))
......@@ -80,20 +79,20 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j
X_np = f(A_func(A_val), b_val)
test_input = transpose_func(A_func(A_val), trans)
ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
X_np = f(A_val.copy(), b_val.copy())
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
np.testing.assert_allclose(
A_val_transformed @ X_np,
b_val,
atol=1e-8 if floatX.endswith("64") else 1e-4,
rtol=1e-8 if floatX.endswith("64") else 1e-4,
)
compiled_fgraph = f.maker.fgraph
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_func(A_val), b_val],
[A_val, b_val],
)
......@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
b_test_nb = b_test_py.copy(order="F")
op = SolveTriangular(
trans=0,
unit_diagonal=False,
lower=False,
check_finite=True,
......
......@@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A():
Solve(assume_a="test", b_ndim=2)
solve_test_cases = [
("gen", False, False),
("gen", False, True),
("sym", False, False),
("sym", True, False),
("sym", True, True),
("pos", False, False),
("pos", True, False),
("pos", True, True),
]
solve_test_ids = [
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
for assume_a, lower, transposed in solve_test_cases
]
class TestSolve(utt.InferShapeTester):
@staticmethod
def A_func(x, assume_a):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
@staticmethod
def T(x, transposed):
if transposed:
return x.T
return x
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
......@@ -235,8 +266,12 @@ class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
def test_solve_correctness(self, b_size: tuple[int], assume_a: str):
@pytest.mark.parametrize(
"assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
)
def test_solve_correctness(
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_size)
......@@ -244,19 +279,18 @@ class TestSolve(utt.InferShapeTester):
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
A_func = functools.partial(self.A_func, assume_a=assume_a)
T = functools.partial(self.T, transposed=transposed)
def A_func(x):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_input_val = A_func(A_val)
y = solve(
A_func(A),
b,
assume_a=assume_a,
lower=lower,
transposed=transposed,
b_ndim=len(b_size),
)
y = solve_op(A_func(A), b)
solve_func = pytensor.function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy())
......@@ -264,22 +298,34 @@ class TestSolve(utt.InferShapeTester):
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
np.testing.assert_allclose(
scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
scipy.linalg.solve(
A_func(A_val),
b_val,
assume_a=assume_a,
transposed=transposed,
lower=lower,
),
X_np,
atol=ATOL,
rtol=RTOL,
)
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL)
np.testing.assert_allclose(T(A_func(A_val)) @ X_np, b_val, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.parametrize(
"assume_a, lower, transposed",
solve_test_cases,
ids=solve_test_ids,
)
@pytest.mark.skipif(
config.floatX == "float32", reason="Gradients not numerically stable in float32"
)
def test_solve_gradient(self, b_size: tuple[int], assume_a: str):
def test_solve_gradient(
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
):
rng = np.random.default_rng(utt.fetch_seed())
eps = 2e-8 if config.floatX == "float64" else None
......@@ -287,15 +333,8 @@ class TestSolve(utt.InferShapeTester):
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
def A_func(x):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
A_func = functools.partial(self.A_func, assume_a=assume_a)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
......@@ -307,11 +346,27 @@ class TestSolve(utt.InferShapeTester):
class TestSolveTriangular(utt.InferShapeTester):
@staticmethod
def A_func(x, lower, unit_diagonal):
x = x @ x.T
x = pt.linalg.cholesky(x, lower=lower)
if unit_diagonal:
x = pt.fill_diagonal(x, 1)
return x
@staticmethod
def T(x, trans):
if trans == 1:
return x.T
elif trans == 2:
return x.conj().T
return x
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
b_val = rng.random(b_shape).astype(config.floatX)
b = pt.as_tensor_variable(b_val).type()
self._compile_and_check(
[A, b],
......@@ -324,56 +379,78 @@ class TestSolveTriangular(utt.InferShapeTester):
warn=False,
)
@pytest.mark.parametrize(
"b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"]
)
@pytest.mark.parametrize("lower", [True, False])
def test_correctness(self, lower):
@pytest.mark.parametrize("trans", [0, 1, 2])
@pytest.mark.parametrize("unit_diagonal", [True, False])
def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_shape)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
A_val = rng.random((5, 5)).astype(config.floatX)
b_val = rng.random(b_shape).astype(config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
A_val = np.dot(A_val.transpose(), A_val)
A_func = functools.partial(
self.A_func, lower=lower, unit_diagonal=unit_diagonal
)
C_val = scipy.linalg.cholesky(A_val, lower=lower)
x = solve_triangular(
A_func(A),
b,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
b_ndim=len(b_shape),
)
A = matrix()
b = matrix()
f = pytensor.function([A, b], x)
cholesky = Cholesky(lower=lower)
C = cholesky(A)
y_lower = solve_triangular(C, b, lower=lower)
lower_solve_func = pytensor.function([C, b], y_lower)
x_pt = f(A_val, b_val)
x_sp = scipy.linalg.solve_triangular(
A_func(A_val).eval(),
b_val,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
)
assert np.allclose(
scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
lower_solve_func(C_val, b_val),
np.testing.assert_allclose(
x_pt,
x_sp,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
)
@pytest.mark.parametrize(
"m, n, lower",
[
(5, None, False),
(5, None, True),
(4, 2, False),
(4, 2, True),
],
"b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"]
)
def test_solve_grad(self, m, n, lower):
rng = np.random.default_rng(utt.fetch_seed())
@pytest.mark.parametrize("lower", [True, False])
@pytest.mark.parametrize("trans", [0, 1])
@pytest.mark.parametrize("unit_diagonal", [True, False])
def test_solve_triangular_grad(self, b_shape, lower, trans, unit_diagonal):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
rng = np.random.default_rng(utt.fetch_seed())
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
A_func = functools.partial(
self.A_func, lower=lower, unit_diagonal=unit_diagonal
)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
def solve_op(A, b):
return solve_triangular(
A_func(A), b, lower=lower, trans=trans, unit_diagonal=unit_diagonal
)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论