提交 79961a62 authored 作者: Fabian Hartmann's avatar Fabian Hartmann 提交者: Brandon T. Willard

Add a SolveTriangular Op

`Solve` has also been changed to match SciPy.
上级 6fce270b
import logging
import warnings
from typing import Union
import numpy as np
import scipy.linalg
......@@ -11,6 +12,7 @@ from aesara.tensor import as_tensor_variable
from aesara.tensor import basic as aet
from aesara.tensor import math as atm
from aesara.tensor.type import matrix, tensor, vector
from aesara.tensor.var import TensorVariable
logger = logging.getLogger(__name__)
......@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
class Solve(Op):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
class SolveBase(Op):
"""Base class for `scipy.linalg` matrix equation solvers."""
__props__ = (
"assume_a",
"lower",
"check_finite", # "transposed"
"check_finite",
)
def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True, # transposed=False
check_finite=True,
):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
self.assume_a = assume_a
self.lower = lower
self.check_finite = check_finite
# self.transposed = transposed
def __repr__(self):
return "Solve{%s}" % str(self._props())
def perform(self, node, inputs, outputs):
pass
def make_node(self, A, b):
A = as_tensor_variable(A)
b = as_tensor_variable(b)
assert A.ndim == 2
assert b.ndim in [1, 2]
# infer dtype by solving the most simple
# case with (1, 1) matrices
if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in [1, 2]:
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve(
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
).dtype
x = tensor(broadcastable=b.broadcastable, dtype=o_dtype)
return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage):
A, b = inputs
if self.assume_a != "gen":
# if self.transposed:
# if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"
rval = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
check_finite=self.check_finite,
# trans=trans
)
else:
rval = scipy.linalg.solve(
A,
b,
assume_a=self.assume_a,
lower=self.lower,
check_finite=self.check_finite,
# transposed=self.transposed,
)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b
def infer_shape(self, fgraph, node, shapes):
Ashape, Bshape = shapes
rows = Ashape[1]
if len(Bshape) == 1: # b is a Vector
if len(Bshape) == 1:
return [(rows,)]
else:
cols = Bshape[1] # b is a Matrix
cols = Bshape[1]
return [(rows, cols)]
def L_op(self, inputs, outputs, output_gradients):
r"""
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_.
......@@ -364,31 +325,148 @@ class Solve(Op):
# We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0]
trans_solve_op = Solve(
assume_a=self.assume_a,
check_finite=self.check_finite,
lower=not self.lower,
trans_solve_op = type(self)(
**{
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
for k in self.__props__
}
)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.assume_a != "gen":
if self.lower:
A_bar = aet.tril(A_bar)
else:
A_bar = aet.triu(A_bar)
return [A_bar, b_bar]
def __repr__(self):
return f"{type(self).__name__}{self._props()}"
class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
__props__ = (
"lower",
"trans",
"unit_diagonal",
"check_finite",
)
def __init__(
self,
trans=0,
lower=False,
unit_diagonal=False,
check_finite=True,
):
super().__init__(lower=lower, check_finite=check_finite)
self.trans = trans
self.unit_diagonal = unit_diagonal
def perform(self, node, inputs, outputs):
A, b = inputs
outputs[0][0] = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
trans=self.trans,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
)
def L_op(self, inputs, outputs, output_gradients):
res = super().L_op(inputs, outputs, output_gradients)
if self.lower:
res[0] = aet.tril(res[0])
else:
res[0] = aet.triu(res[0])
return res
solvetriangular = SolveTriangular()
def solve_triangular(
a: TensorVariable,
b: TensorVariable,
trans: Union[int, str] = 0,
lower: bool = False,
unit_diagonal: bool = False,
check_finite: bool = True,
) -> TensorVariable:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
Parameters
----------
a
Square input data
b
Input data for the right hand side.
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
trans: {0, 1, 2, ‘N’, ‘T’, ‘C’}, optional
Type of system to solve:
trans system
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
return SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)(a, b)
class Solve(SolveBase):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
__props__ = (
"assume_a",
"lower",
"check_finite",
)
def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True,
):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
super().__init__(lower=lower, check_finite=check_finite)
self.assume_a = assume_a
def perform(self, node, inputs, outputs):
a, b = inputs
outputs[0][0] = scipy.linalg.solve(
a=a,
b=b,
lower=self.lower,
check_finite=self.check_finite,
assume_a=self.assume_a,
)
solve = Solve()
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""
Solves the linear equation set ``a * x = b`` for the unknown ``x``
for square ``a`` matrix.
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
corresponding string to ``assume_a`` key chooses the dedicated solver.
......@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
# TODO: These are deprecated; emit a warning
solve_lower_triangular = Solve(assume_a="sym", lower=True)
solve_upper_triangular = Solve(assume_a="sym", lower=False)
solve_lower_triangular = SolveTriangular(lower=True)
solve_upper_triangular = SolveTriangular(lower=False)
solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse
......
......@@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc):
"gen",
None,
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
aet.dmatrix(),
......@@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc):
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
import functools
import itertools
import numpy as np
......@@ -14,12 +15,15 @@ from aesara.tensor.slinalg import (
CholeskyGrad,
CholeskySolve,
Solve,
SolveBase,
SolveTriangular,
cho_solve,
cholesky,
eigvalsh,
expm,
kron,
solve,
solve_triangular,
)
from aesara.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
......@@ -170,122 +174,107 @@ def test_eigvalsh_grad():
)
class TestSolve(utt.InferShapeTester):
def setup_method(self):
self.op_class = Solve
self.op = Solve()
super().setup_method()
def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())
class TestSolveBase(utt.InferShapeTester):
@pytest.mark.parametrize(
"A_func, b_func, error_message",
[
(vector, matrix, "`A` must be a matrix.*"),
(
functools.partial(tensor, dtype="floatX", broadcastable=(False,) * 3),
matrix,
"`A` must be a matrix.*",
),
(
matrix,
functools.partial(tensor, dtype="floatX", broadcastable=(False,) * 3),
"`b` must be a matrix or a vector.*",
),
],
)
def test_make_node(self, A_func, b_func, error_message):
np.random.default_rng(utt.fetch_seed())
with pytest.raises(ValueError, match=error_message):
A = A_func()
b = b_func()
SolveBase()(A, b)
def test__repr__(self):
np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
self._compile_and_check(
[A, b], # aesara.function inputs
[self.op(A, b)], # aesara.function outputs
# A must be square
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
np.asarray(rng.random((5, 1)), dtype=config.floatX),
],
self.op_class,
warn=False,
)
y = SolveBase()(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0"
class TestSolve(utt.InferShapeTester):
def test__init__(self):
with pytest.raises(ValueError) as excinfo:
Solve(assume_a="test")
assert "is not a recognized matrix structure" in str(excinfo.value)
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = vector()
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
b = aet.as_tensor_variable(b_val).type()
self._compile_and_check(
[A, b], # aesara.function inputs
[self.op(A, b)], # aesara.function outputs
# A must be square
[A, b],
[solve(A, b)],
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
np.asarray(rng.random((5)), dtype=config.floatX),
b_val,
],
self.op_class,
Solve,
warn=False,
)
def test_solve_correctness(self):
def test_correctness(self):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = self.op(A, b)
y = solve(A, b)
gen_solve_func = aesara.function([A, b], y)
cholesky_lower = Cholesky(lower=True)
L = cholesky_lower(A)
y_lower = self.op(L, b)
lower_solve_func = aesara.function([L, b], y_lower)
cholesky_upper = Cholesky(lower=False)
U = cholesky_upper(A)
y_upper = self.op(U, b)
upper_solve_func = aesara.function([U, b], y_upper)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
# 1-test general case
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
# positive definite matrix:
A_val = np.dot(A_val.transpose(), A_val)
assert np.allclose(
scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val)
)
# 2-test lower traingular case
L_val = scipy.linalg.cholesky(A_val, lower=True)
assert np.allclose(
scipy.linalg.solve_triangular(L_val, b_val, lower=True),
lower_solve_func(L_val, b_val),
A_undef = np.array(
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0],
],
dtype=config.floatX,
)
# 3-test upper traingular case
U_val = scipy.linalg.cholesky(A_val, lower=False)
assert np.allclose(
scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val),
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
)
def test_solve_dtype(self):
dtypes = [
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
]
A_val = np.eye(2)
b_val = np.ones((2, 1))
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype)
x = solve(A, b)
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
assert x.dtype == x_result.dtype
@pytest.mark.parametrize(
"m, n, assume_a, lower",
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
],
)
def test_solve_grad(self, m, n, assume_a, lower):
rng = np.random.default_rng(utt.fetch_seed())
def verify_solve_grad(self, m, n, assume_a, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if assume_a != "gen":
if lower:
A_val = np.tril(A_val)
else:
A_val = np.triu(A_val)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
......@@ -298,22 +287,76 @@ class TestSolve(utt.InferShapeTester):
solve_op = Solve(assume_a=assume_a, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestSolveTriangular(utt.InferShapeTester):
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
b = aet.as_tensor_variable(b_val).type()
self._compile_and_check(
[A, b],
[solve_triangular(A, b)],
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
b_val,
],
SolveTriangular,
warn=False,
)
@pytest.mark.parametrize("lower", [True, False])
def test_correctness(self, lower):
rng = np.random.default_rng(utt.fetch_seed())
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
A_val = np.dot(A_val.transpose(), A_val)
C_val = scipy.linalg.cholesky(A_val, lower=lower)
A = matrix()
b = matrix()
cholesky = Cholesky(lower=lower)
C = cholesky(A)
y_lower = solve_triangular(C, b, lower=lower)
lower_solve_func = aesara.function([C, b], y_lower)
assert np.allclose(
scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
lower_solve_func(C_val, b_val),
)
@pytest.mark.parametrize(
"m, n, assume_a, lower",
"m, n, lower",
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
(5, None, "sym", False),
(5, None, "sym", True),
(4, 2, "sym", False),
(4, 2, "sym", True),
(5, None, False),
(5, None, True),
(4, 2, False),
(4, 2, True),
],
)
def test_solve_grad(self, m, n, assume_a, lower):
def test_solve_grad(self, m, n, lower):
rng = np.random.default_rng(utt.fetch_seed())
self.verify_solve_grad(m, n, assume_a, lower, rng)
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = SolveTriangular(lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论