提交 86282bdd 作者: Brandon T. Willard 提交者: Brandon T. Willard

Update aesara.tensor.slinalg.Solve to match SciPy interface

上级 a6e461bf
...@@ -800,7 +800,7 @@ def jax_funcify_Cholesky(op, **kwargs): ...@@ -800,7 +800,7 @@ def jax_funcify_Cholesky(op, **kwargs):
@jax_funcify.register(Solve) @jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs): def jax_funcify_Solve(op, **kwargs):
if op.A_structure == "lower_triangular": if op.assume_a != "gen" and op.lower:
lower = True lower = True
else: else:
lower = False lower = False
......
...@@ -1690,9 +1690,12 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -1690,9 +1690,12 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
if op.A_structure == "lower_triangular" or op.A_structure == "upper_triangular": assume_a = op.assume_a
# check_finite = op.check_finite
lower = op.A_structure == "lower_triangular" if assume_a != "gen":
lower = op.lower
warnings.warn( warnings.warn(
( (
...@@ -1707,16 +1710,26 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1707,16 +1710,26 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba.njit @numba.njit
def solve(a, b): def solve(a, b):
with numba.objmode(ret=ret_sig): with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(a, b, lower=lower) ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret return ret
else: else:
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def solve(a, b): def solve(a, b):
return np.linalg.solve(inputs_cast(a), inputs_cast(b)).astype(out_dtype) return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve return solve
......
...@@ -249,25 +249,25 @@ def tag_solve_triangular(fgraph, node): ...@@ -249,25 +249,25 @@ def tag_solve_triangular(fgraph, node):
replace it with a triangular solve. replace it with a triangular solve.
""" """
if node.op == solve: if isinstance(node.op, Solve):
if node.op.A_structure == "general": if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, type(cholesky)): if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower: if A.owner.op.lower:
return [Solve("lower_triangular")(A, b)] return [Solve(assume_a="sym", lower=True)(A, b)]
else: else:
return [Solve("upper_triangular")(A, b)] return [Solve(assume_a="sym", lower=False)(A, b)]
if ( if (
A.owner A.owner
and isinstance(A.owner.op, DimShuffle) and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0) and A.owner.op.new_order == (1, 0)
): ):
(A_T,) = A.owner.inputs (A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)): if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower: if A_T.owner.op.lower:
return [Solve("upper_triangular")(A, b)] return [Solve(assume_a="sym", lower=False)(A, b)]
else: else:
return [Solve("lower_triangular")(A, b)] return [Solve(assume_a="sym", lower=True)(A, b)]
@register_canonicalize @register_canonicalize
...@@ -286,15 +286,15 @@ def no_transpose_symmetric(fgraph, node): ...@@ -286,15 +286,15 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize @register_stabilize
@local_optimizer(None) # XXX: solve is defined later and can't be used here @local_optimizer(None) # XXX: solve is defined later and can't be used here
def psd_solve_with_chol(fgraph, node): def psd_solve_with_chol(fgraph, node):
if node.op == solve: if isinstance(node.op, Solve):
A, b = node.inputs # result is solution Ax=b A, b = node.inputs # result is solution Ax=b
if is_psd(A): if is_psd(A):
L = cholesky(A) L = cholesky(A)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op # N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the L matrix during the # __if__ no other Op makes use of the L matrix during the
# stabilization # stabilization
Li_b = Solve("lower_triangular")(L, b) Li_b = Solve(assume_a="sym", lower=True)(L, b)
x = Solve("upper_triangular")(L.T, Li_b) x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
return [x] return [x]
......
...@@ -59,6 +59,7 @@ from aesara.tensor import ( ...@@ -59,6 +59,7 @@ from aesara.tensor import (
nlinalg, nlinalg,
nnet, nnet,
opt_uncanonicalize, opt_uncanonicalize,
slinalg,
xlogx, xlogx,
) )
from aesara.tensor.basic import * from aesara.tensor.basic import *
......
...@@ -5,27 +5,16 @@ import numpy as np ...@@ -5,27 +5,16 @@ import numpy as np
import scipy.linalg import scipy.linalg
import aesara.tensor import aesara.tensor
import aesara.tensor.basic as aet
import aesara.tensor.math as tm
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.tensor import as_tensor_variable from aesara.tensor import as_tensor_variable
from aesara.tensor import basic as aet
from aesara.tensor import math as atm
from aesara.tensor.type import matrix, tensor, vector from aesara.tensor.type import matrix, tensor, vector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MATRIX_STRUCTURES = (
"general",
"symmetric",
"lower_triangular",
"upper_triangular",
"hermitian",
"banded",
"diagonal",
"toeplitz",
)
class Cholesky(Op): class Cholesky(Op):
""" """
...@@ -95,7 +84,7 @@ class Cholesky(Op): ...@@ -95,7 +84,7 @@ class Cholesky(Op):
# Replace the cholesky decomposition with 1 if there are nans # Replace the cholesky decomposition with 1 if there are nans
# or solve_upper_triangular will throw a ValueError. # or solve_upper_triangular will throw a ValueError.
if self.on_error == "nan": if self.on_error == "nan":
ok = ~tm.any(tm.isnan(chol_x)) ok = ~atm.any(atm.isnan(chol_x))
chol_x = aet.switch(ok, chol_x, 1) chol_x = aet.switch(ok, chol_x, 1)
dz = aet.switch(ok, dz, 1) dz = aet.switch(ok, dz, 1)
...@@ -206,17 +195,24 @@ class Solve(Op): ...@@ -206,17 +195,24 @@ class Solve(Op):
For use on CPU and GPU. For use on CPU and GPU.
""" """
__props__ = ("A_structure", "lower", "overwrite_A", "overwrite_b") __props__ = (
"assume_a",
"lower",
"check_finite", # "transposed"
)
def __init__( def __init__(
self, A_structure="general", lower=False, overwrite_A=False, overwrite_b=False self,
assume_a="gen",
lower=False,
check_finite=True, # transposed=False
): ):
if A_structure not in MATRIX_STRUCTURES: if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError("Invalid matrix structure argument", A_structure) raise ValueError(f"{assume_a} is not a recognized matrix structure")
self.A_structure = A_structure self.assume_a = assume_a
self.lower = lower self.lower = lower
self.overwrite_A = overwrite_A self.check_finite = check_finite
self.overwrite_b = overwrite_b # self.transposed = transposed
def __repr__(self): def __repr__(self):
return "Solve{%s}" % str(self._props()) return "Solve{%s}" % str(self._props())
...@@ -237,12 +233,33 @@ class Solve(Op): ...@@ -237,12 +233,33 @@ class Solve(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
if self.A_structure == "lower_triangular":
rval = scipy.linalg.solve_triangular(A, b, lower=True) if self.assume_a != "gen":
elif self.A_structure == "upper_triangular": # if self.transposed:
rval = scipy.linalg.solve_triangular(A, b, lower=False) # if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"
rval = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
check_finite=self.check_finite,
# trans=trans
)
else: else:
rval = scipy.linalg.solve(A, b) rval = scipy.linalg.solve(
A,
b,
assume_a=self.assume_a,
lower=self.lower,
check_finite=self.check_finite,
# transposed=self.transposed,
)
output_storage[0][0] = rval output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b # computes shape of x where x = inv(A) * b
...@@ -257,7 +274,7 @@ class Solve(Op): ...@@ -257,7 +274,7 @@ class Solve(Op):
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
r""" r"""
Reverse-mode gradient updates for matrix solve operation c = A \\\ b. Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_. Symbolic expression for updates taken from [#]_.
...@@ -269,53 +286,84 @@ class Solve(Op): ...@@ -269,53 +286,84 @@ class Solve(Op):
""" """
A, b = inputs A, b = inputs
c = outputs[0] c = outputs[0]
# C is a scalar representing the entire graph
# `output_gradients` is (dC/dc,)
# We need to return (dC/dA, dC/db)
c_bar = output_gradients[0] c_bar = output_gradients[0]
trans_map = {
"lower_triangular": "upper_triangular",
"upper_triangular": "lower_triangular",
}
trans_solve_op = Solve( trans_solve_op = Solve(
# update A_structure and lower to account for a transpose operation assume_a=self.assume_a,
A_structure=trans_map.get(self.A_structure, self.A_structure), check_finite=self.check_finite,
lower=not self.lower, lower=not self.lower,
) )
b_bar = trans_solve_op(A.T, c_bar) b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input # force outer product if vector second input
A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.A_structure == "lower_triangular":
A_bar = aet.tril(A_bar) if self.assume_a != "gen":
elif self.A_structure == "upper_triangular": if self.lower:
A_bar = aet.triu(A_bar) A_bar = aet.tril(A_bar)
else:
A_bar = aet.triu(A_bar)
return [A_bar, b_bar] return [A_bar, b_bar]
solve = Solve() solve = Solve()
"""
Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
``b`` can be either a vector or a matrix. def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""
Parameters Solves the linear equation set ``a * x = b`` for the unknown ``x``
---------- for square ``a`` matrix.
a : `(M, M) symbolic matrix`
A square matrix If the data matrix is known to be a particular type then supplying the
b : `(M,) or (M, N) symbolic vector or matrix` corresponding string to ``assume_a`` key chooses the dedicated solver.
Right hand side matrix in ``a x = b`` The available options are
=================== ========
Returns generic matrix 'gen'
------- symmetric 'sym'
x : `(M, ) or (M, N) symbolic vector or matrix` hermitian 'her'
x will have the same shape as b positive definite 'pos'
""" =================== ========
# lower and upper triangular solves
solve_lower_triangular = Solve(A_structure="lower_triangular", lower=True) If omitted, ``'gen'`` is the default structure.
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is lower triangular."""
solve_upper_triangular = Solve(A_structure="upper_triangular", lower=False) The datatype of the arrays defines which solver is called regardless
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is upper triangular.""" of the values. In other words, even when the complex array entries have
# symmetric solves precisely zero imaginary parts, the complex solver will be called based
solve_symmetric = Solve(A_structure="symmetric") on the data type of the array.
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is symmetric."""
Parameters
----------
a : (N, N) array_like
Square input data
b : (N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
Valid entries are explained above.
"""
return Solve(
lower=lower,
check_finite=check_finite,
assume_a=assume_a,
)(a, b)
# TODO: These are deprecated; emit a warning
solve_lower_triangular = Solve(assume_a="sym", lower=True)
solve_upper_triangular = Solve(assume_a="sym", lower=False)
solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten) # with solve() Op (still unwritten)
...@@ -456,7 +504,7 @@ def kron(a, b): ...@@ -456,7 +504,7 @@ def kron(a, b):
"kron: inputs dimensions must sum to 3 or more. " "kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}." f"You passed {int(a.ndim)} and {int(b.ndim)}."
) )
o = tm.outer(a, b) o = atm.outer(a, b)
o = o.reshape(aet.concatenate((a.shape, b.shape)), a.ndim + b.ndim) o = o.reshape(aet.concatenate((a.shape, b.shape)), a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim))) shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
if shf.ndim == 3: if shf.ndim == 3:
......
...@@ -2000,7 +2000,7 @@ def test_Cholesky(x, lower, exc): ...@@ -2000,7 +2000,7 @@ def test_Cholesky(x, lower, exc):
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")), set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"general", "gen",
None, None,
), ),
( (
...@@ -2011,7 +2011,7 @@ def test_Cholesky(x, lower, exc): ...@@ -2011,7 +2011,7 @@ def test_Cholesky(x, lower, exc):
), ),
), ),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")), set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"general", "gen",
None, None,
), ),
( (
...@@ -2020,7 +2020,7 @@ def test_Cholesky(x, lower, exc): ...@@ -2020,7 +2020,7 @@ def test_Cholesky(x, lower, exc):
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")), set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
"lower_triangular", "sym",
UserWarning, UserWarning,
), ),
], ],
......
...@@ -144,12 +144,12 @@ def test_tag_solve_triangular(): ...@@ -144,12 +144,12 @@ def test_tag_solve_triangular():
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort(): for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve): if isinstance(node.op, Solve):
assert node.op.A_structure == "lower_triangular" assert node.op.assume_a != "gen" and node.op.lower
f = aesara.function([A, x], b2) f = aesara.function([A, x], b2)
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort(): for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve): if isinstance(node.op, Solve):
assert node.op.A_structure == "upper_triangular" assert node.op.assume_a != "gen" and not node.op.lower
def test_matrix_inverse_solve(): def test_matrix_inverse_solve():
......
...@@ -273,38 +273,48 @@ class TestSolve(utt.InferShapeTester): ...@@ -273,38 +273,48 @@ class TestSolve(utt.InferShapeTester):
assert x.dtype == x_result.dtype assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng): def verify_solve_grad(self, m, n, assume_a, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical # ensure diagonal elements of A relatively large to avoid numerical
# precision issues # precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if A_structure == "lower_triangular":
A_val = np.tril(A_val) if assume_a != "gen":
elif A_structure == "upper_triangular": if lower:
A_val = np.triu(A_val) A_val = np.tril(A_val)
else:
A_val = np.triu(A_val)
if n is None: if n is None:
b_val = rng.normal(size=m).astype(config.floatX) b_val = rng.normal(size=m).astype(config.floatX)
else: else:
b_val = rng.normal(size=(m, n)).astype(config.floatX) b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None eps = None
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = Solve(A_structure=A_structure, lower=lower)
solve_op = Solve(assume_a=assume_a, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
@pytest.mark.parametrize(
"m, n, assume_a, lower",
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
(5, None, "sym", False),
(5, None, "sym", True),
(4, 2, "sym", False),
(4, 2, "sym", True),
],
)
def test_solve_grad(self, m, n, assume_a, lower):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
structures = ["general", "lower_triangular", "upper_triangular"] self.verify_solve_grad(m, n, assume_a, lower, rng)
for A_structure in structures:
lower = A_structure == "lower_triangular"
self.verify_solve_grad(5, None, A_structure, lower, rng)
self.verify_solve_grad(6, 1, A_structure, lower, rng)
self.verify_solve_grad(4, 3, A_structure, lower, rng)
# lower should have no effect for A_structure == 'general' so also
# check lower=True case
self.verify_solve_grad(4, 3, "general", lower=True, rng=rng)
def test_expm(): def test_expm():
scipy = pytest.importorskip("scipy")
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = rng.standard_normal((5, 5)).astype(config.floatX) A = rng.standard_normal((5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论