提交 a149f6c9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Enable new `assume_a` in Solve

上级 6e06f811
import warnings
import jax
from pytensor.link.jax.dispatch.basic import jax_funcify
......@@ -39,13 +41,29 @@ def jax_funcify_Cholesky(op, **kwargs):
@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
if op.assume_a != "gen" and op.lower:
lower = True
assume_a = op.assume_a
lower = op.lower
if assume_a == "tridiagonal":
# jax.scipy.solve does not yet support tridiagonal matrices
# But there's a jax.lax.linalg.tridiaonal_solve we can use instead.
def solve(a, b):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
else:
lower = False
if assume_a not in ("gen", "sym", "her", "pos"):
warnings.warn(
f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
UserWarning,
)
assume_a = "gen"
def solve(a, b, lower=lower):
return jax.scipy.linalg.solve(a, b, lower=lower)
def solve(a, b):
return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)
return solve
......
import warnings
from collections.abc import Callable
import numba
......@@ -1071,14 +1072,17 @@ def numba_funcify_Solve(op, node, **kwargs):
elif assume_a == "sym":
solve_fn = _solve_symmetric
elif assume_a == "her":
raise NotImplementedError(
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
"please open an issue on github."
)
# We already ruled out complex inputs
solve_fn = _solve_symmetric
elif assume_a == "pos":
solve_fn = _solve_psd
else:
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
warnings.warn(
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.",
UserWarning,
)
solve_fn = _solve_gen
@numba_basic.numba_njit(inline="always")
def solve(a, b):
......
......@@ -15,6 +15,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
......@@ -260,10 +261,10 @@ class SolveBase(Op):
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
out_arr = [[None]]
self.perform(None, inp_arr, out_arr)
o_dtype = out_arr[0][0].dtype
o_dtype = scipy_linalg.solve(
np.ones((1, 1), dtype=A.dtype),
np.ones((1,), dtype=b.dtype),
).dtype
x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [A, b], [x])
......@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
return min(b.ndim, 2) # By default, assume the core case is a matrix
class CholeskySolve(SolveBase):
......@@ -332,6 +333,19 @@ class CholeskySolve(SolveBase):
kwargs.setdefault("lower", True)
super().__init__(**kwargs)
def make_node(self, *inputs):
# Allow base class to do input validation
super_apply = super().make_node(*inputs)
A, b = super_apply.inputs
[super_out] = super_apply.outputs
# The dtype of chol_solve does not match solve, which the base class checks
dtype = scipy_linalg.cho_solve(
(np.ones((1, 1), dtype=A.dtype), False),
np.ones((1,), dtype=b.dtype),
).dtype
out = tensor(dtype=dtype, shape=super_out.type.shape)
return Apply(self, [A, b], [out])
def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy_linalg.cho_solve(
......@@ -499,8 +513,33 @@ class Solve(SolveBase):
)
def __init__(self, *, assume_a="gen", **kwargs):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
# Triangular and diagonal are handled outside of Solve
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
assume_a = assume_a.lower()
# We use the old names as the different dispatches are more likely to support them
long_to_short = {
"general": "gen",
"symmetric": "sym",
"hermitian": "her",
"positive definite": "pos",
}
assume_a = long_to_short.get(assume_a, assume_a)
if assume_a not in valid_options:
raise ValueError(
f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
)
if assume_a in ("tridiagonal", "banded"):
from scipy import __version__ as sp_version
if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
warnings.warn(
f"assume_a={assume_a} requires scipy>=1.5.0. Defaulting to assume_a='gen'.",
UserWarning,
)
assume_a = "gen"
super().__init__(**kwargs)
self.assume_a = assume_a
......@@ -536,10 +575,12 @@ def solve(
a,
b,
*,
assume_a="gen",
lower=False,
transposed=False,
check_finite=True,
lower: bool = False,
overwrite_a: bool = False,
overwrite_b: bool = False,
check_finite: bool = True,
assume_a: str = "gen",
transposed: bool = False,
b_ndim: int | None = None,
):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
......@@ -548,14 +589,19 @@ def solve(
corresponding string to ``assume_a`` key chooses the dedicated solver.
The available options are
=================== ========
generic matrix 'gen'
symmetric 'sym'
hermitian 'her'
positive definite 'pos'
=================== ========
=================== ================================
diagonal 'diagonal'
tridiagonal 'tridiagonal'
banded 'banded'
upper triangular 'upper triangular'
lower triangular 'lower triangular'
symmetric 'symmetric' (or 'sym')
hermitian 'hermitian' (or 'her')
positive definite 'positive definite' (or 'pos')
general 'general' (or 'gen')
=================== ================================
If omitted, ``'gen'`` is the default structure.
If omitted, ``'general'`` is the default structure.
The datatype of the arrays define which solver is called regardless
of the values. In other words, even when the complex array entries have
......@@ -568,23 +614,52 @@ def solve(
Square input data
b : (..., N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
lower : bool, default False
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
If True, the calculation uses only the data in the lower triangle of `a`;
entries above the diagonal are ignored. If False (default), the
calculation uses only the data in the upper triangle of `a`; entries
below the diagonal are ignored.
overwrite_a : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
Valid entries are explained above.
transposed: bool, default False
If True, solves the system A^T x = b. Default is False.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
"""
assume_a = assume_a.lower()
if assume_a in ("lower triangular", "upper triangular"):
lower = "lower" in assume_a
return solve_triangular(
a,
b,
lower=lower,
trans=transposed,
check_finite=check_finite,
b_ndim=b_ndim,
)
b_ndim = _default_b_ndim(b, b_ndim)
if assume_a == "diagonal":
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
b_transposed = b[None, :] if b_ndim == 1 else b.mT
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
if b_ndim == 1:
x = x.squeeze(-1)
return x
if transposed:
a = a.mT
lower = not lower
......
......@@ -10,6 +10,8 @@ import pytensor
from pytensor import function, grad
from pytensor import tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import (
Cholesky,
CholeskySolve,
......@@ -211,8 +213,8 @@ class TestSolveBase:
)
def test_solve_raises_on_invalid_A():
with pytest.raises(ValueError, match="is not a recognized matrix structure"):
def test_solve_raises_on_invalid_assume_a():
with pytest.raises(ValueError, match="Invalid assume_a: test. It must be one of"):
Solve(assume_a="test", b_ndim=2)
......@@ -225,6 +227,10 @@ solve_test_cases = [
("pos", False, False),
("pos", True, False),
("pos", True, True),
("diagonal", False, False),
("diagonal", False, True),
("tridiagonal", False, False),
("tridiagonal", False, True),
]
solve_test_ids = [
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
......@@ -239,6 +245,16 @@ class TestSolve(utt.InferShapeTester):
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
elif assume_a == "diagonal":
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
return x * eye_fn(x.shape[1])
elif assume_a == "tridiagonal":
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
return x * (
eye_fn(x.shape[1], k=0)
+ eye_fn(x.shape[1], k=-1)
+ eye_fn(x.shape[1], k=1)
)
else:
return x
......@@ -346,6 +362,22 @@ class TestSolve(utt.InferShapeTester):
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
)
def test_solve_tringular_indirection(self):
a = pt.matrix("a")
b = pt.vector("b")
indirect = solve(a, b, assume_a="lower triangular")
direct = solve_triangular(a, b, lower=True, trans=False)
assert equal_computations([indirect], [direct])
indirect = solve(a, b, assume_a="upper triangular")
direct = solve_triangular(a, b, lower=False, trans=False)
assert equal_computations([indirect], [direct])
indirect = solve(a, b, assume_a="upper triangular", transposed=True)
direct = solve_triangular(a, b, lower=False, trans=True)
assert equal_computations([indirect], [direct])
class TestSolveTriangular(utt.InferShapeTester):
@staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论