提交 79961a62 authored 作者: Fabian Hartmann's avatar Fabian Hartmann 提交者: Brandon T. Willard

Add a SolveTriangular Op

`Solve` has also been changed to match SciPy.
上级 6fce270b
import logging import logging
import warnings import warnings
from typing import Union
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
...@@ -11,6 +12,7 @@ from aesara.tensor import as_tensor_variable ...@@ -11,6 +12,7 @@ from aesara.tensor import as_tensor_variable
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor import math as atm from aesara.tensor import math as atm
from aesara.tensor.type import matrix, tensor, vector from aesara.tensor.type import matrix, tensor, vector
from aesara.tensor.var import TensorVariable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True): ...@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
class Solve(Op): class SolveBase(Op):
""" """Base class for `scipy.linalg` matrix equation solvers."""
Solve a system of linear equations.
For on CPU and GPU.
"""
__props__ = ( __props__ = (
"assume_a",
"lower", "lower",
"check_finite", # "transposed" "check_finite",
) )
def __init__( def __init__(
self, self,
assume_a="gen",
lower=False, lower=False,
check_finite=True, # transposed=False check_finite=True,
): ):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
self.assume_a = assume_a
self.lower = lower self.lower = lower
self.check_finite = check_finite self.check_finite = check_finite
# self.transposed = transposed
def __repr__(self): def perform(self, node, inputs, outputs):
return "Solve{%s}" % str(self._props()) pass
def make_node(self, A, b): def make_node(self, A, b):
A = as_tensor_variable(A) A = as_tensor_variable(A)
b = as_tensor_variable(b) b = as_tensor_variable(b)
assert A.ndim == 2
assert b.ndim in [1, 2]
# infer dtype by solving the most simple if A.ndim != 2:
# case with (1, 1) matrices raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in [1, 2]:
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve( o_dtype = scipy.linalg.solve(
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype) np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
).dtype ).dtype
x = tensor(broadcastable=b.broadcastable, dtype=o_dtype) x = tensor(broadcastable=b.broadcastable, dtype=o_dtype)
return Apply(self, [A, b], [x]) return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage):
A, b = inputs
if self.assume_a != "gen":
# if self.transposed:
# if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"
rval = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
check_finite=self.check_finite,
# trans=trans
)
else:
rval = scipy.linalg.solve(
A,
b,
assume_a=self.assume_a,
lower=self.lower,
check_finite=self.check_finite,
# transposed=self.transposed,
)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
Ashape, Bshape = shapes Ashape, Bshape = shapes
rows = Ashape[1] rows = Ashape[1]
if len(Bshape) == 1: # b is a Vector if len(Bshape) == 1:
return [(rows,)] return [(rows,)]
else: else:
cols = Bshape[1] # b is a Matrix cols = Bshape[1]
return [(rows, cols)] return [(rows, cols)]
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
r""" r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_. Symbolic expression for updates taken from [#]_.
...@@ -364,31 +325,148 @@ class Solve(Op): ...@@ -364,31 +325,148 @@ class Solve(Op):
# We need to return (dC/d[inv(A)], dC/db) # We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0] c_bar = output_gradients[0]
trans_solve_op = Solve( trans_solve_op = type(self)(
assume_a=self.assume_a, **{
check_finite=self.check_finite, k: (not getattr(self, k) if k == "lower" else getattr(self, k))
lower=not self.lower, for k in self.__props__
}
) )
b_bar = trans_solve_op(A.T, c_bar) b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input # force outer product if vector second input
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.assume_a != "gen":
if self.lower:
A_bar = aet.tril(A_bar)
else:
A_bar = aet.triu(A_bar)
return [A_bar, b_bar] return [A_bar, b_bar]
def __repr__(self):
return f"{type(self).__name__}{self._props()}"
class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
__props__ = (
"lower",
"trans",
"unit_diagonal",
"check_finite",
)
def __init__(
self,
trans=0,
lower=False,
unit_diagonal=False,
check_finite=True,
):
super().__init__(lower=lower, check_finite=check_finite)
self.trans = trans
self.unit_diagonal = unit_diagonal
def perform(self, node, inputs, outputs):
A, b = inputs
outputs[0][0] = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
trans=self.trans,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
)
def L_op(self, inputs, outputs, output_gradients):
res = super().L_op(inputs, outputs, output_gradients)
if self.lower:
res[0] = aet.tril(res[0])
else:
res[0] = aet.triu(res[0])
return res
solvetriangular = SolveTriangular()
def solve_triangular(
a: TensorVariable,
b: TensorVariable,
trans: Union[int, str] = 0,
lower: bool = False,
unit_diagonal: bool = False,
check_finite: bool = True,
) -> TensorVariable:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
Parameters
----------
a
Square input data
b
Input data for the right hand side.
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
trans: {0, 1, 2, ‘N’, ‘T’, ‘C’}, optional
Type of system to solve:
trans system
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
return SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)(a, b)
class Solve(SolveBase):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
__props__ = (
"assume_a",
"lower",
"check_finite",
)
def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True,
):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
super().__init__(lower=lower, check_finite=check_finite)
self.assume_a = assume_a
def perform(self, node, inputs, outputs):
a, b = inputs
outputs[0][0] = scipy.linalg.solve(
a=a,
b=b,
lower=self.lower,
check_finite=self.check_finite,
assume_a=self.assume_a,
)
solve = Solve() solve = Solve()
def solve(a, b, assume_a="gen", lower=False, check_finite=True): def solve(a, b, assume_a="gen", lower=False, check_finite=True):
""" """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
Solves the linear equation set ``a * x = b`` for the unknown ``x``
for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the If the data matrix is known to be a particular type then supplying the
corresponding string to ``assume_a`` key chooses the dedicated solver. corresponding string to ``assume_a`` key chooses the dedicated solver.
...@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True): ...@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
# TODO: These are deprecated; emit a warning # TODO: These are deprecated; emit a warning
solve_lower_triangular = Solve(assume_a="sym", lower=True) solve_lower_triangular = SolveTriangular(lower=True)
solve_upper_triangular = Solve(assume_a="sym", lower=False) solve_upper_triangular = SolveTriangular(lower=False)
solve_symmetric = Solve(assume_a="sym") solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
......
...@@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc): ...@@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc):
"gen", "gen",
None, None,
), ),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
( (
set_test_value( set_test_value(
aet.dmatrix(), aet.dmatrix(),
...@@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc): ...@@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc):
), ),
], ],
) )
def test_Solve(A, x, lower, exc): def test_SolveTriangular(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x) g = slinalg.SolveTriangular(lower)(A, x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
......
import functools
import itertools import itertools
import numpy as np import numpy as np
...@@ -14,12 +15,15 @@ from aesara.tensor.slinalg import ( ...@@ -14,12 +15,15 @@ from aesara.tensor.slinalg import (
CholeskyGrad, CholeskyGrad,
CholeskySolve, CholeskySolve,
Solve, Solve,
SolveBase,
SolveTriangular,
cho_solve, cho_solve,
cholesky, cholesky,
eigvalsh, eigvalsh,
expm, expm,
kron, kron,
solve, solve,
solve_triangular,
) )
from aesara.tensor.type import dmatrix, matrix, tensor, vector from aesara.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -170,122 +174,107 @@ def test_eigvalsh_grad(): ...@@ -170,122 +174,107 @@ def test_eigvalsh_grad():
) )
class TestSolve(utt.InferShapeTester): class TestSolveBase(utt.InferShapeTester):
def setup_method(self): @pytest.mark.parametrize(
self.op_class = Solve "A_func, b_func, error_message",
self.op = Solve() [
super().setup_method() (vector, matrix, "`A` must be a matrix.*"),
(
def test_infer_shape(self): functools.partial(tensor, dtype="floatX", broadcastable=(False,) * 3),
rng = np.random.default_rng(utt.fetch_seed()) matrix,
"`A` must be a matrix.*",
),
(
matrix,
functools.partial(tensor, dtype="floatX", broadcastable=(False,) * 3),
"`b` must be a matrix or a vector.*",
),
],
)
def test_make_node(self, A_func, b_func, error_message):
np.random.default_rng(utt.fetch_seed())
with pytest.raises(ValueError, match=error_message):
A = A_func()
b = b_func()
SolveBase()(A, b)
def test__repr__(self):
np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
self._compile_and_check( y = SolveBase()(A, b)
[A, b], # aesara.function inputs assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0"
[self.op(A, b)], # aesara.function outputs
# A must be square
[ class TestSolve(utt.InferShapeTester):
np.asarray(rng.random((5, 5)), dtype=config.floatX), def test__init__(self):
np.asarray(rng.random((5, 1)), dtype=config.floatX), with pytest.raises(ValueError) as excinfo:
], Solve(assume_a="test")
self.op_class, assert "is not a recognized matrix structure" in str(excinfo.value)
warn=False,
) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = vector() b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
b = aet.as_tensor_variable(b_val).type()
self._compile_and_check( self._compile_and_check(
[A, b], # aesara.function inputs [A, b],
[self.op(A, b)], # aesara.function outputs [solve(A, b)],
# A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
np.asarray(rng.random((5)), dtype=config.floatX), b_val,
], ],
self.op_class, Solve,
warn=False, warn=False,
) )
def test_solve_correctness(self): def test_correctness(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
y = self.op(A, b) y = solve(A, b)
gen_solve_func = aesara.function([A, b], y) gen_solve_func = aesara.function([A, b], y)
cholesky_lower = Cholesky(lower=True)
L = cholesky_lower(A)
y_lower = self.op(L, b)
lower_solve_func = aesara.function([L, b], y_lower)
cholesky_upper = Cholesky(lower=False)
U = cholesky_upper(A)
y_upper = self.op(U, b)
upper_solve_func = aesara.function([U, b], y_upper)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
# 1-test general case
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
# positive definite matrix:
A_val = np.dot(A_val.transpose(), A_val) A_val = np.dot(A_val.transpose(), A_val)
assert np.allclose( assert np.allclose(
scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val) scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val)
) )
# 2-test lower traingular case A_undef = np.array(
L_val = scipy.linalg.cholesky(A_val, lower=True) [
assert np.allclose( [1, 0, 0, 0, 0],
scipy.linalg.solve_triangular(L_val, b_val, lower=True), [0, 1, 0, 0, 0],
lower_solve_func(L_val, b_val), [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0],
],
dtype=config.floatX,
) )
# 3-test upper traingular case
U_val = scipy.linalg.cholesky(A_val, lower=False)
assert np.allclose( assert np.allclose(
scipy.linalg.solve_triangular(U_val, b_val, lower=False), scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
upper_solve_func(U_val, b_val),
) )
def test_solve_dtype(self): @pytest.mark.parametrize(
dtypes = [ "m, n, assume_a, lower",
"uint8", [
"uint16", (5, None, "gen", False),
"uint32", (5, None, "gen", True),
"uint64", (4, 2, "gen", False),
"int8", (4, 2, "gen", True),
"int16", ],
"int32", )
"int64", def test_solve_grad(self, m, n, assume_a, lower):
"float16", rng = np.random.default_rng(utt.fetch_seed())
"float32",
"float64",
]
A_val = np.eye(2)
b_val = np.ones((2, 1))
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype)
x = solve(A, b)
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, assume_a, lower, rng): # Ensure diagonal elements of `A` are relatively large to avoid
# ensure diagonal elements of A relatively large to avoid numerical # numerical precision issues
# precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if assume_a != "gen":
if lower:
A_val = np.tril(A_val)
else:
A_val = np.triu(A_val)
if n is None: if n is None:
b_val = rng.normal(size=m).astype(config.floatX) b_val = rng.normal(size=m).astype(config.floatX)
else: else:
...@@ -298,22 +287,76 @@ class TestSolve(utt.InferShapeTester): ...@@ -298,22 +287,76 @@ class TestSolve(utt.InferShapeTester):
solve_op = Solve(assume_a=assume_a, lower=lower) solve_op = Solve(assume_a=assume_a, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestSolveTriangular(utt.InferShapeTester):
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
b = aet.as_tensor_variable(b_val).type()
self._compile_and_check(
[A, b],
[solve_triangular(A, b)],
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
b_val,
],
SolveTriangular,
warn=False,
)
@pytest.mark.parametrize("lower", [True, False])
def test_correctness(self, lower):
rng = np.random.default_rng(utt.fetch_seed())
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
A_val = np.dot(A_val.transpose(), A_val)
C_val = scipy.linalg.cholesky(A_val, lower=lower)
A = matrix()
b = matrix()
cholesky = Cholesky(lower=lower)
C = cholesky(A)
y_lower = solve_triangular(C, b, lower=lower)
lower_solve_func = aesara.function([C, b], y_lower)
assert np.allclose(
scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
lower_solve_func(C_val, b_val),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"m, n, assume_a, lower", "m, n, lower",
[ [
(5, None, "gen", False), (5, None, False),
(5, None, "gen", True), (5, None, True),
(4, 2, "gen", False), (4, 2, False),
(4, 2, "gen", True), (4, 2, True),
(5, None, "sym", False),
(5, None, "sym", True),
(4, 2, "sym", False),
(4, 2, "sym", True),
], ],
) )
def test_solve_grad(self, m, n, assume_a, lower): def test_solve_grad(self, m, n, lower):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
self.verify_solve_grad(m, n, assume_a, lower, rng)
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = SolveTriangular(lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester): class TestCholeskySolve(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论