提交 a149f6c9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Enable new `assume_a` in Solve

上级 6e06f811
import warnings
import jax import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
...@@ -39,13 +41,29 @@ def jax_funcify_Cholesky(op, **kwargs): ...@@ -39,13 +41,29 @@ def jax_funcify_Cholesky(op, **kwargs):
@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
    """Build the JAX callable that implements a ``Solve`` op.

    Parameters
    ----------
    op
        The ``Solve`` op being dispatched. Its ``assume_a``, ``lower`` and
        ``b_ndim`` attributes are read here.
    **kwargs
        Unused; accepted for compatibility with the dispatch signature.

    Returns
    -------
    Callable
        A function ``solve(a, b)`` returning the solution of ``a @ x = b``.

    Warns
    -----
    UserWarning
        If ``op.assume_a`` is not one of the structures handled here, in which
        case the general solver (``assume_a="gen"``) is used instead.
    """
    assume_a = op.assume_a
    lower = op.lower

    if assume_a == "tridiagonal":
        # jax.scipy.linalg.solve does not yet support tridiagonal matrices,
        # but there is a jax.lax.linalg.tridiagonal_solve we can use instead.
        b_is_vector = op.b_ndim == 1

        def solve(a, b):
            dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
            d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
            du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)

            # tridiagonal_solve requires dl, d and du to share the same shape,
            # with dl[..., 0] == 0 and du[..., -1] == 0. The off-diagonals
            # extracted above are one element short, so pad the last axis only.
            leading_pad = [(0, 0)] * (d.ndim - 1)
            dl = jax.numpy.pad(dl, [*leading_pad, (1, 0)])
            du = jax.numpy.pad(du, [*leading_pad, (0, 1)])

            if b_is_vector:
                # The right-hand side is documented as a matrix: add a
                # trailing column axis for the call and drop it afterwards.
                x = jax.lax.linalg.tridiagonal_solve(dl, d, du, b[..., None])
                return x[..., 0]

            # tridiagonal_solve has no `lower` argument: the whole band is used.
            return jax.lax.linalg.tridiagonal_solve(dl, d, du, b)

    else:
        if assume_a not in ("gen", "sym", "her", "pos"):
            warnings.warn(
                f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
                f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
                UserWarning,
            )
            assume_a = "gen"

        def solve(a, b):
            return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)

    return solve
......
import warnings
from collections.abc import Callable from collections.abc import Callable
import numba import numba
...@@ -1071,14 +1072,17 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1071,14 +1072,17 @@ def numba_funcify_Solve(op, node, **kwargs):
elif assume_a == "sym": elif assume_a == "sym":
solve_fn = _solve_symmetric solve_fn = _solve_symmetric
elif assume_a == "her": elif assume_a == "her":
raise NotImplementedError( # We already ruled out complex inputs
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, ' solve_fn = _solve_symmetric
"please open an issue on github."
)
elif assume_a == "pos": elif assume_a == "pos":
solve_fn = _solve_psd solve_fn = _solve_psd
else: else:
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode") warnings.warn(
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.",
UserWarning,
)
solve_fn = _solve_gen
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def solve(a, b): def solve(a, b):
......
...@@ -15,6 +15,7 @@ from pytensor.graph.op import Op ...@@ -15,6 +15,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
from pytensor.tensor.basic import diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
...@@ -260,10 +261,10 @@ class SolveBase(Op): ...@@ -260,10 +261,10 @@ class SolveBase(Op):
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.") raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices # Infer dtype by solving the most simple case with 1x1 matrices
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)] o_dtype = scipy_linalg.solve(
out_arr = [[None]] np.ones((1, 1), dtype=A.dtype),
self.perform(None, inp_arr, out_arr) np.ones((1,), dtype=b.dtype),
o_dtype = out_arr[0][0].dtype ).dtype
x = tensor(dtype=o_dtype, shape=b.type.shape) x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [A, b], [x]) return Apply(self, [A, b], [x])
...@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim): ...@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
b = as_tensor_variable(b) b = as_tensor_variable(b)
if b_ndim is None: if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix return min(b.ndim, 2) # By default, assume the core case is a matrix
class CholeskySolve(SolveBase): class CholeskySolve(SolveBase):
...@@ -332,6 +333,19 @@ class CholeskySolve(SolveBase): ...@@ -332,6 +333,19 @@ class CholeskySolve(SolveBase):
kwargs.setdefault("lower", True) kwargs.setdefault("lower", True)
super().__init__(**kwargs) super().__init__(**kwargs)
def make_node(self, *inputs):
    """Create the Apply node for a Cholesky solve.

    Input validation is delegated to the base class; only the output dtype
    is recomputed, because ``cho_solve`` does not promote dtypes the same
    way as the ``solve`` call the base class uses for its inference.
    """
    base_node = super().make_node(*inputs)
    A, b = base_node.inputs
    (base_out,) = base_node.outputs

    # Infer the output dtype by running cho_solve on the smallest possible
    # problem built with the input dtypes.
    probe_factor = np.ones((1, 1), dtype=A.dtype)
    probe_rhs = np.ones((1,), dtype=b.dtype)
    out_dtype = scipy_linalg.cho_solve((probe_factor, False), probe_rhs).dtype

    out = tensor(dtype=out_dtype, shape=base_out.type.shape)
    return Apply(self, [A, b], [out])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
C, b = inputs C, b = inputs
rval = scipy_linalg.cho_solve( rval = scipy_linalg.cho_solve(
...@@ -499,8 +513,33 @@ class Solve(SolveBase): ...@@ -499,8 +513,33 @@ class Solve(SolveBase):
) )
def __init__(self, *, assume_a="gen", **kwargs):
    """Initialise the op with a normalised ``assume_a`` structure flag.

    Parameters
    ----------
    assume_a : str
        Assumed structure of the matrix. Matching is case-insensitive. The
        long names ``"general"``, ``"symmetric"``, ``"hermitian"`` and
        ``"positive definite"`` are mapped to ``"gen"``, ``"sym"``, ``"her"``
        and ``"pos"``. ``"tridiagonal"`` and ``"banded"`` are also accepted.
    **kwargs
        Forwarded unchanged to the base class constructor.

    Raises
    ------
    ValueError
        If ``assume_a`` is not one of the recognised options.

    Warns
    -----
    UserWarning
        If ``"tridiagonal"`` or ``"banded"`` is requested with a scipy older
        than 1.15; ``assume_a`` then falls back to ``"gen"``.
    """
    # Triangular and diagonal are handled outside of Solve
    valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]

    assume_a = assume_a.lower()
    # We use the old names as the different dispatches are more likely to support them
    long_to_short = {
        "general": "gen",
        "symmetric": "sym",
        "hermitian": "her",
        "positive definite": "pos",
    }
    assume_a = long_to_short.get(assume_a, assume_a)

    if assume_a not in valid_options:
        raise ValueError(
            f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
        )

    if assume_a in ("tridiagonal", "banded"):
        from scipy import __version__ as sp_version

        # Only (major, minor) matter for this check. Taking the first two
        # components avoids calling int() on pre-release or local segments
        # (e.g. "1.16.0.dev0+git20250101.abc"), which are not plain integers.
        if tuple(map(int, sp_version.split(".")[:2])) < (1, 15):
            warnings.warn(
                f"assume_a={assume_a} requires scipy>=1.15.0. Defaulting to assume_a='gen'.",
                UserWarning,
            )
            assume_a = "gen"

    super().__init__(**kwargs)
    self.assume_a = assume_a
...@@ -536,10 +575,12 @@ def solve( ...@@ -536,10 +575,12 @@ def solve(
a, a,
b, b,
*, *,
assume_a="gen", lower: bool = False,
lower=False, overwrite_a: bool = False,
transposed=False, overwrite_b: bool = False,
check_finite=True, check_finite: bool = True,
assume_a: str = "gen",
transposed: bool = False,
b_ndim: int | None = None, b_ndim: int | None = None,
): ):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
...@@ -548,14 +589,19 @@ def solve( ...@@ -548,14 +589,19 @@ def solve(
corresponding string to ``assume_a`` key chooses the dedicated solver. corresponding string to ``assume_a`` key chooses the dedicated solver.
The available options are The available options are
=================== ======== =================== ================================
generic matrix 'gen' diagonal 'diagonal'
symmetric 'sym' tridiagonal 'tridiagonal'
hermitian 'her' banded 'banded'
positive definite 'pos' upper triangular 'upper triangular'
=================== ======== lower triangular 'lower triangular'
symmetric 'symmetric' (or 'sym')
hermitian 'hermitian' (or 'her')
positive definite 'positive definite' (or 'pos')
general 'general' (or 'gen')
=================== ================================
If omitted, ``'gen'`` is the default structure. If omitted, ``'general'`` is the default structure.
The datatype of the arrays define which solver is called regardless The datatype of the arrays define which solver is called regardless
of the values. In other words, even when the complex array entries have of the values. In other words, even when the complex array entries have
...@@ -568,23 +614,52 @@ def solve( ...@@ -568,23 +614,52 @@ def solve(
Square input data Square input data
b : (..., N, NRHS) array_like b : (..., N, NRHS) array_like
Input data for the right hand side. Input data for the right hand side.
lower : bool, optional lower : bool, default False
If True, use only the data contained in the lower triangle of `a`. Default Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
is to use upper triangle. (ignored for ``'gen'``) If True, the calculation uses only the data in the lower triangle of `a`;
transposed: bool, optional entries above the diagonal are ignored. If False (default), the
If True, solves the system A^T x = b. Default is False. calculation uses only the data in the upper triangle of `a`; entries
below the diagonal are ignored.
overwrite_a : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional assume_a : str, optional
Valid entries are explained above. Valid entries are explained above.
transposed: bool, default False
If True, solves the system A^T x = b. Default is False.
b_ndim : int b_ndim : int
Whether the core case of b is a vector (1) or matrix (2). Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
""" """
assume_a = assume_a.lower()
if assume_a in ("lower triangular", "upper triangular"):
lower = "lower" in assume_a
return solve_triangular(
a,
b,
lower=lower,
trans=transposed,
check_finite=check_finite,
b_ndim=b_ndim,
)
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
if assume_a == "diagonal":
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
b_transposed = b[None, :] if b_ndim == 1 else b.mT
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
if b_ndim == 1:
x = x.squeeze(-1)
return x
if transposed: if transposed:
a = a.mT a = a.mT
lower = not lower lower = not lower
......
...@@ -10,6 +10,8 @@ import pytensor ...@@ -10,6 +10,8 @@ import pytensor
from pytensor import function, grad from pytensor import function, grad
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -211,8 +213,8 @@ class TestSolveBase: ...@@ -211,8 +213,8 @@ class TestSolveBase:
) )
def test_solve_raises_on_invalid_assume_a():
    """An unrecognised ``assume_a`` value is rejected when the op is constructed."""
    expected_message = "Invalid assume_a: test. It must be one of"
    with pytest.raises(ValueError, match=expected_message):
        Solve(assume_a="test", b_ndim=2)
...@@ -225,6 +227,10 @@ solve_test_cases = [ ...@@ -225,6 +227,10 @@ solve_test_cases = [
("pos", False, False), ("pos", False, False),
("pos", True, False), ("pos", True, False),
("pos", True, True), ("pos", True, True),
("diagonal", False, False),
("diagonal", False, True),
("tridiagonal", False, False),
("tridiagonal", False, True),
] ]
solve_test_ids = [ solve_test_ids = [
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}' f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
...@@ -239,6 +245,16 @@ class TestSolve(utt.InferShapeTester): ...@@ -239,6 +245,16 @@ class TestSolve(utt.InferShapeTester):
return x @ x.T return x @ x.T
elif assume_a == "sym": elif assume_a == "sym":
return (x + x.T) / 2 return (x + x.T) / 2
elif assume_a == "diagonal":
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
return x * eye_fn(x.shape[1])
elif assume_a == "tridiagonal":
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
return x * (
eye_fn(x.shape[1], k=0)
+ eye_fn(x.shape[1], k=-1)
+ eye_fn(x.shape[1], k=1)
)
else: else:
return x return x
...@@ -346,6 +362,22 @@ class TestSolve(utt.InferShapeTester): ...@@ -346,6 +362,22 @@ class TestSolve(utt.InferShapeTester):
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
) )
def test_solve_tringular_indirection(self):
    """``solve`` with a triangular ``assume_a`` builds the same graph as ``solve_triangular``."""
    a = pt.matrix("a")
    b = pt.vector("b")

    # (assume_a passed to solve, transposed flag, expected `lower` of solve_triangular)
    cases = [
        ("lower triangular", False, True),
        ("upper triangular", False, False),
        ("upper triangular", True, False),
    ]
    for structure, transposed, expected_lower in cases:
        indirect = solve(a, b, assume_a=structure, transposed=transposed)
        direct = solve_triangular(a, b, lower=expected_lower, trans=transposed)
        assert equal_computations([indirect], [direct])
class TestSolveTriangular(utt.InferShapeTester): class TestSolveTriangular(utt.InferShapeTester):
@staticmethod @staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论