提交 86282bdd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update aesara.tensor.slinalg.Solve to match SciPy interface

上级 a6e461bf
......@@ -800,7 +800,7 @@ def jax_funcify_Cholesky(op, **kwargs):
@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
if op.A_structure == "lower_triangular":
if op.assume_a != "gen" and op.lower:
lower = True
else:
lower = False
......
......@@ -1690,9 +1690,12 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
if op.A_structure == "lower_triangular" or op.A_structure == "upper_triangular":
assume_a = op.assume_a
# check_finite = op.check_finite
lower = op.A_structure == "lower_triangular"
if assume_a != "gen":
lower = op.lower
warnings.warn(
(
......@@ -1707,16 +1710,26 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba.njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(a, b, lower=lower)
ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def solve(a, b):
return np.linalg.solve(inputs_cast(a), inputs_cast(b)).astype(out_dtype)
return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve
......
......@@ -249,25 +249,25 @@ def tag_solve_triangular(fgraph, node):
replace it with a triangular solve.
"""
if node.op == solve:
if node.op.A_structure == "general":
if isinstance(node.op, Solve):
if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, type(cholesky)):
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [Solve("lower_triangular")(A, b)]
return [Solve(assume_a="sym", lower=True)(A, b)]
else:
return [Solve("upper_triangular")(A, b)]
return [Solve(assume_a="sym", lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [Solve("upper_triangular")(A, b)]
return [Solve(assume_a="sym", lower=False)(A, b)]
else:
return [Solve("lower_triangular")(A, b)]
return [Solve(assume_a="sym", lower=True)(A, b)]
@register_canonicalize
......@@ -286,15 +286,15 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize
@local_optimizer(None) # XXX: solve is defined later and can't be used here
def psd_solve_with_chol(fgraph, node):
if node.op == solve:
if isinstance(node.op, Solve):
A, b = node.inputs # result is solution Ax=b
if is_psd(A):
L = cholesky(A)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the the L matrix during the
# stabilization
Li_b = Solve("lower_triangular")(L, b)
x = Solve("upper_triangular")(L.T, Li_b)
Li_b = Solve(assume_a="sym", lower=True)(L, b)
x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
return [x]
......
......@@ -59,6 +59,7 @@ from aesara.tensor import (
nlinalg,
nnet,
opt_uncanonicalize,
slinalg,
xlogx,
)
from aesara.tensor.basic import *
......
......@@ -5,27 +5,16 @@ import numpy as np
import scipy.linalg
import aesara.tensor
import aesara.tensor.basic as aet
import aesara.tensor.math as tm
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor import as_tensor_variable
from aesara.tensor import basic as aet
from aesara.tensor import math as atm
from aesara.tensor.type import matrix, tensor, vector
logger = logging.getLogger(__name__)
MATRIX_STRUCTURES = (
"general",
"symmetric",
"lower_triangular",
"upper_triangular",
"hermitian",
"banded",
"diagonal",
"toeplitz",
)
class Cholesky(Op):
"""
......@@ -95,7 +84,7 @@ class Cholesky(Op):
# Replace the cholesky decomposition with 1 if there are nans
# or solve_upper_triangular will throw a ValueError.
if self.on_error == "nan":
ok = ~tm.any(tm.isnan(chol_x))
ok = ~atm.any(atm.isnan(chol_x))
chol_x = aet.switch(ok, chol_x, 1)
dz = aet.switch(ok, dz, 1)
......@@ -206,17 +195,24 @@ class Solve(Op):
For on CPU and GPU.
"""
__props__ = ("A_structure", "lower", "overwrite_A", "overwrite_b")
__props__ = (
"assume_a",
"lower",
"check_finite", # "transposed"
)
def __init__(
self, A_structure="general", lower=False, overwrite_A=False, overwrite_b=False
self,
assume_a="gen",
lower=False,
check_finite=True, # transposed=False
):
if A_structure not in MATRIX_STRUCTURES:
raise ValueError("Invalid matrix structure argument", A_structure)
self.A_structure = A_structure
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
self.assume_a = assume_a
self.lower = lower
self.overwrite_A = overwrite_A
self.overwrite_b = overwrite_b
self.check_finite = check_finite
# self.transposed = transposed
def __repr__(self):
return "Solve{%s}" % str(self._props())
......@@ -237,12 +233,33 @@ class Solve(Op):
def perform(self, node, inputs, output_storage):
A, b = inputs
if self.A_structure == "lower_triangular":
rval = scipy.linalg.solve_triangular(A, b, lower=True)
elif self.A_structure == "upper_triangular":
rval = scipy.linalg.solve_triangular(A, b, lower=False)
if self.assume_a != "gen":
# if self.transposed:
# if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"
rval = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
check_finite=self.check_finite,
# trans=trans
)
else:
rval = scipy.linalg.solve(A, b)
rval = scipy.linalg.solve(
A,
b,
assume_a=self.assume_a,
lower=self.lower,
check_finite=self.check_finite,
# transposed=self.transposed,
)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b
......@@ -257,7 +274,7 @@ class Solve(Op):
def L_op(self, inputs, outputs, output_gradients):
r"""
Reverse-mode gradient updates for matrix solve operation c = A \\\ b.
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_.
......@@ -269,53 +286,84 @@ class Solve(Op):
"""
A, b = inputs
c = outputs[0]
# C is a scalar representing the entire graph
# `output_gradients` is (dC/dc,)
# We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0]
trans_map = {
"lower_triangular": "upper_triangular",
"upper_triangular": "lower_triangular",
}
trans_solve_op = Solve(
# update A_structure and lower to account for a transpose operation
A_structure=trans_map.get(self.A_structure, self.A_structure),
assume_a=self.assume_a,
check_finite=self.check_finite,
lower=not self.lower,
)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.A_structure == "lower_triangular":
A_bar = aet.tril(A_bar)
elif self.A_structure == "upper_triangular":
A_bar = aet.triu(A_bar)
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.assume_a != "gen":
if self.lower:
A_bar = aet.tril(A_bar)
else:
A_bar = aet.triu(A_bar)
return [A_bar, b_bar]
solve = Solve()
"""
Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
``b`` can be either a vector or a matrix.
Parameters
----------
a : `(M, M) symbolix matrix`
A square matrix
b : `(M,) or (M, N) symbolic vector or matrix`
Right hand side matrix in ``a x = b``
Returns
-------
x : `(M, ) or (M, N) symbolic vector or matrix`
x will have the same shape as b
"""
# lower and upper triangular solves
solve_lower_triangular = Solve(A_structure="lower_triangular", lower=True)
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is lower triangular."""
solve_upper_triangular = Solve(A_structure="upper_triangular", lower=False)
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is upper triangular."""
# symmetric solves
solve_symmetric = Solve(A_structure="symmetric")
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is symmetric."""
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""
Solves the linear equation set ``a * x = b`` for the unknown ``x``
for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
corresponding string to ``assume_a`` key chooses the dedicated solver.
The available options are
=================== ========
generic matrix 'gen'
symmetric 'sym'
hermitian 'her'
positive definite 'pos'
=================== ========
If omitted, ``'gen'`` is the default structure.
The datatype of the arrays define which solver is called regardless
of the values. In other words, even when the complex array entries have
precisely zero imaginary parts, the complex solver will be called based
on the data type of the array.
Parameters
----------
a : (N, N) array_like
Square input data
b : (N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
Valid entries are explained above.
"""
return Solve(
lower=lower,
check_finite=check_finite,
assume_a=assume_a,
)(a, b)
# TODO: These are deprecated; emit a warning
solve_lower_triangular = Solve(assume_a="sym", lower=True)
solve_upper_triangular = Solve(assume_a="sym", lower=False)
solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
......@@ -456,7 +504,7 @@ def kron(a, b):
"kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}."
)
o = tm.outer(a, b)
o = atm.outer(a, b)
o = o.reshape(aet.concatenate((a.shape, b.shape)), a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
if shf.ndim == 3:
......
......@@ -2000,7 +2000,7 @@ def test_Cholesky(x, lower, exc):
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"general",
"gen",
None,
),
(
......@@ -2011,7 +2011,7 @@ def test_Cholesky(x, lower, exc):
),
),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"general",
"gen",
None,
),
(
......@@ -2020,7 +2020,7 @@ def test_Cholesky(x, lower, exc):
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"lower_triangular",
"sym",
UserWarning,
),
],
......
......@@ -144,12 +144,12 @@ def test_tag_solve_triangular():
if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == "lower_triangular"
assert node.op.assume_a != "gen" and node.op.lower
f = aesara.function([A, x], b2)
if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == "upper_triangular"
assert node.op.assume_a != "gen" and not node.op.lower
def test_matrix_inverse_solve():
......
......@@ -273,38 +273,48 @@ class TestSolve(utt.InferShapeTester):
assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng):
def verify_solve_grad(self, m, n, assume_a, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if A_structure == "lower_triangular":
A_val = np.tril(A_val)
elif A_structure == "upper_triangular":
A_val = np.triu(A_val)
if assume_a != "gen":
if lower:
A_val = np.tril(A_val)
else:
A_val = np.triu(A_val)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = Solve(A_structure=A_structure, lower=lower)
solve_op = Solve(assume_a=assume_a, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
@pytest.mark.parametrize(
"m, n, assume_a, lower",
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
(5, None, "sym", False),
(5, None, "sym", True),
(4, 2, "sym", False),
(4, 2, "sym", True),
],
)
def test_solve_grad(self, m, n, assume_a, lower):
rng = np.random.default_rng(utt.fetch_seed())
structures = ["general", "lower_triangular", "upper_triangular"]
for A_structure in structures:
lower = A_structure == "lower_triangular"
self.verify_solve_grad(5, None, A_structure, lower, rng)
self.verify_solve_grad(6, 1, A_structure, lower, rng)
self.verify_solve_grad(4, 3, A_structure, lower, rng)
# lower should have no effect for A_structure == 'general' so also
# check lower=True case
self.verify_solve_grad(4, 3, "general", lower=True, rng=rng)
self.verify_solve_grad(m, n, assume_a, lower, rng)
def test_expm():
scipy = pytest.importorskip("scipy")
rng = np.random.default_rng(utt.fetch_seed())
A = rng.standard_normal((5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论