提交 bbe663d9 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Implement numba dispatch for all `linalg.solve` modes

上级 8e5e8a40
差异被折叠。
...@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs):
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from a Pytensor `Op`."""
warnings.warn( warnings.warn(
f"Numba will use object mode to run {op}'s perform method", f"Numba will use object mode to run {op}'s perform method",
......
import logging import logging
import typing
import warnings import warnings
from collections.abc import Sequence
from functools import reduce from functools import reduce
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
...@@ -58,7 +58,7 @@ class Cholesky(Op): ...@@ -58,7 +58,7 @@ class Cholesky(Op):
f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
) )
# Call scipy to find output dtype # Call scipy to find output dtype
dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -68,21 +68,21 @@ class Cholesky(Op): ...@@ -68,21 +68,21 @@ class Cholesky(Op):
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]: if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy.linalg.cholesky( out[0] = scipy_linalg.cholesky(
x.T, x.T,
lower=not self.lower, lower=not self.lower,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_a=True, overwrite_a=True,
).T ).T
else: else:
out[0] = scipy.linalg.cholesky( out[0] = scipy_linalg.cholesky(
x, x,
lower=self.lower, lower=self.lower,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_a=self.overwrite_a, overwrite_a=self.overwrite_a,
) )
except scipy.linalg.LinAlgError: except scipy_linalg.LinAlgError:
if self.on_error == "raise": if self.on_error == "raise":
raise raise
else: else:
...@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase): ...@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
C, b = inputs C, b = inputs
rval = scipy.linalg.cho_solve( rval = scipy_linalg.cho_solve(
(C, self.lower), (C, self.lower),
b, b,
check_finite=self.check_finite, check_finite=self.check_finite,
...@@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): ...@@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int b_ndim : int
Whether the core case of b is a vector (1) or matrix (2). Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
...@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase): ...@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
A, b = inputs A, b = inputs
outputs[0][0] = scipy.linalg.solve_triangular( outputs[0][0] = scipy_linalg.solve_triangular(
A, A,
b, b,
lower=self.lower, lower=self.lower,
...@@ -502,7 +502,7 @@ class Solve(SolveBase): ...@@ -502,7 +502,7 @@ class Solve(SolveBase):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
a, b = inputs a, b = inputs
outputs[0][0] = scipy.linalg.solve( outputs[0][0] = scipy_linalg.solve(
a=a, a=a,
b=b, b=b,
lower=self.lower, lower=self.lower,
...@@ -619,9 +619,9 @@ class Eigvalsh(Op): ...@@ -619,9 +619,9 @@ class Eigvalsh(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(w,) = outputs (w,) = outputs
if len(inputs) == 2: if len(inputs) == 2:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
else: else:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
a, b = inputs a, b = inputs
...@@ -675,7 +675,7 @@ class EigvalshGrad(Op): ...@@ -675,7 +675,7 @@ class EigvalshGrad(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(a, b, gw) = inputs (a, b, gw) = inputs
w, v = scipy.linalg.eigh(a, b, lower=self.lower) w, v = scipy_linalg.eigh(a, b, lower=self.lower)
gA = v.dot(np.diag(gw).dot(v.T)) gA = v.dot(np.diag(gw).dot(v.T))
gB = -v.dot(np.diag(gw * w).dot(v.T)) gB = -v.dot(np.diag(gw * w).dot(v.T))
...@@ -718,7 +718,7 @@ class Expm(Op): ...@@ -718,7 +718,7 @@ class Expm(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(A,) = inputs (A,) = inputs
(expm,) = outputs (expm,) = outputs
expm[0] = scipy.linalg.expm(A) expm[0] = scipy_linalg.expm(A)
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
(A,) = inputs (A,) = inputs
...@@ -758,8 +758,8 @@ class ExpmGrad(Op): ...@@ -758,8 +758,8 @@ class ExpmGrad(Op):
# this expression. # this expression.
(A, gA) = inputs (A, gA) = inputs
(out,) = outputs (out,) = outputs
w, V = scipy.linalg.eig(A, right=True) w, V = scipy_linalg.eig(A, right=True)
U = scipy.linalg.inv(V).T U = scipy_linalg.inv(V).T
exp_w = np.exp(w) exp_w = np.exp(w)
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
...@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op): ...@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op):
X = output_storage[0] X = output_storage[0]
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op): ...@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op):
X = output_storage[0] X = output_storage[0]
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype out_dtype
) )
...@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op): ...@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op):
Q = 0.5 * (Q + Q.T) Q = 0.5 * (Q + Q.T)
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -1064,7 +1064,7 @@ def solve_discrete_are( ...@@ -1064,7 +1064,7 @@ def solve_discrete_are(
) )
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
...@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal): ...@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal):
def perform(self, node, inputs, output_storage, params=None): def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype) output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype)
def block_diag(*matrices: TensorVariable): def block_diag(*matrices: TensorVariable):
...@@ -1175,4 +1175,5 @@ __all__ = [ ...@@ -1175,4 +1175,5 @@ __all__ = [
"solve_discrete_are", "solve_discrete_are",
"solve_triangular", "solve_triangular",
"block_diag", "block_diag",
"cho_solve",
] ]
...@@ -7,58 +7,13 @@ import pytensor.tensor as pt ...@@ -7,58 +7,13 @@ import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg, slinalg from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
"gen",
None,
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, exc", "x, exc",
[ [
......
...@@ -209,12 +209,12 @@ class TestSolveBase: ...@@ -209,12 +209,12 @@ class TestSolveBase:
) )
class TestSolve(utt.InferShapeTester): def test_solve_raises_on_invalid_A():
def test__init__(self): with pytest.raises(ValueError, match="is not a recognized matrix structure"):
with pytest.raises(ValueError) as excinfo: Solve(assume_a="test", b_ndim=2)
Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value)
class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape): def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester): ...@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester):
warn=False, warn=False,
) )
def test_correctness(self): @pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
)
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
def test_solve_correctness(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = pt.tensor("A", shape=(5, 5))
b = matrix() b = pt.tensor("b", shape=b_size)
y = solve(A, b)
gen_solve_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
A_val = np.dot(A_val.transpose(), A_val)
np.testing.assert_allclose( def A_func(x):
scipy.linalg.solve(A_val, b_val, assume_a="gen"), if assume_a == "pos":
gen_solve_func(A_val, b_val), return x @ x.T
) elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_input_val = A_func(A_val)
y = solve_op(A_func(A), b)
solve_func = pytensor.function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy())
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
A_undef = np.array(
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0],
],
dtype=config.floatX,
)
np.testing.assert_allclose( np.testing.assert_allclose(
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a),
X_np,
atol=ATOL,
rtol=RTOL,
) )
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"m, n, assume_a, lower", "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
[
(5, None, "gen", False),
(5, None, "gen", True),
(4, 2, "gen", False),
(4, 2, "gen", True),
],
) )
def test_solve_grad(self, m, n, assume_a, lower): @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
@pytest.mark.skipif(
config.floatX == "float32", reason="Gradients not numerically stable in float32"
)
def test_solve_gradient(self, b_size: tuple[int], assume_a: str):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
# Ensure diagonal elements of `A` are relatively large to avoid eps = 2e-8 if config.floatX == "float64" else None
# numerical precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
if n is None: A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=m).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None def A_func(x):
if config.floatX == "float64": if assume_a == "pos":
eps = 2e-8 return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
# the random perturbations used by verify_grad will result in invalid input matrices, and
# LAPACK will silently do the wrong thing, making the gradients wrong
utt.verify_grad(
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
)
class TestSolveTriangular(utt.InferShapeTester): class TestSolveTriangular(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论