Commit a3eed0b4, authored by Ricardo Vieira, committed by Thomas Wiecki

Blockwise some linalg Ops by default

Parent commit: 7fb4e70a
......@@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg
def swapaxes(y, axis1, axis2):
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor."
y = as_tensor_variable(y)
ndim = y.ndim
......
......@@ -10,11 +10,13 @@ from pytensor.graph.op import Op
from pytensor.tensor import basic as at
from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
class MatrixPinv(Op):
__props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"
def __init__(self, hermitian):
self.hermitian = hermitian
......@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op.
"""
return MatrixPinv(hermitian=hermitian)(x)
return Blockwise(MatrixPinv(hermitian=hermitian))(x)
class MatrixInverse(Op):
......@@ -93,6 +95,8 @@ class MatrixInverse(Op):
"""
__props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)
def __init__(self):
pass
......@@ -150,7 +154,7 @@ class MatrixInverse(Op):
return shapes
# Wrap the core MatrixInverse Op in Blockwise so that inputs with extra
# leading (batch) dimensions are handled by broadcasting over the core
# (m,m)->(m,m) signature.
inv = matrix_inverse = Blockwise(MatrixInverse())
def matrix_dot(*args):
......@@ -181,6 +185,8 @@ class Det(Op):
"""
__props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -209,7 +215,7 @@ class Det(Op):
return "Det"
# Blockwise-wrapped determinant: batched inputs map over the core
# (m,m)->() signature declared on Det.
det = Blockwise(Det())
class SLogDet(Op):
......@@ -218,6 +224,8 @@ class SLogDet(Op):
"""
__props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -242,7 +250,7 @@ class SLogDet(Op):
return "SLogDet"
# Blockwise-wrapped sign-and-log-determinant: batched inputs map over the
# core (m,m)->(),() signature declared on SLogDet.
slogdet = Blockwise(SLogDet())
class Eig(Op):
......@@ -252,6 +260,8 @@ class Eig(Op):
"""
__props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -270,7 +280,7 @@ class Eig(Op):
return [(n,), (n, n)]
# Blockwise-wrapped eigendecomposition: batched inputs map over the core
# (m,m)->(m),(m,m) signature declared on Eig.
eig = Blockwise(Eig())
class Eigh(Eig):
......
import logging
import typing
import warnings
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
import scipy.linalg
......@@ -13,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import math as atm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
......@@ -48,6 +49,7 @@ class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower
......@@ -109,7 +111,7 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner):
    """Computes L^{-T} P L^{-1} for lower-triangular L.

    Uses an upper-triangular solve on the transpose of ``outer``, applied
    twice (once for each side of ``inner``).
    """
    # b_ndim=2 because both right-hand sides below are matrices.
    solve_upper = SolveTriangular(lower=False, b_ndim=2)
    return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
s = conjugate_solve_triangular(
......@@ -128,7 +130,7 @@ class Cholesky(Op):
def cholesky(x, lower=True, on_error="raise"):
    """Return the Cholesky factor of ``x``.

    The core Cholesky Op is wrapped in Blockwise so that inputs with extra
    leading (batch) dimensions broadcast over the core (m,m)->(m,m) case.

    Parameters
    ----------
    x
        Symmetric positive-definite matrix (or batch thereof).
    lower : bool
        Whether to return the lower-triangular (True) or upper-triangular
        factor.
    on_error : str
        Error-handling mode forwarded to the Cholesky Op (e.g. "raise").
    """
    return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
class SolveBase(Op):
......@@ -137,6 +139,7 @@ class SolveBase(Op):
__props__ = (
"lower",
"check_finite",
"b_ndim",
)
def __init__(
    self,
    *,
    lower=False,
    check_finite=True,
    b_ndim,
):
    """Base initializer for solve-like Ops.

    Parameters
    ----------
    lower : bool
        Whether A is assumed lower-triangular (meaning depends on subclass).
    check_finite : bool
        Whether to check inputs for non-finite values before solving.
    b_ndim : int
        Core ndim of the right-hand side: 1 (vector) or 2 (matrix).
        Determines the gufunc signature used by Blockwise.
    """
    self.lower = lower
    self.check_finite = check_finite
    assert b_ndim in (1, 2)
    self.b_ndim = b_ndim
    # The gufunc signature depends on whether b's core case is a vector
    # or a matrix; Blockwise uses it to interpret batch dimensions.
    if b_ndim == 1:
        self.gufunc_signature = "(m,m),(m)->(m)"
    else:
        self.gufunc_signature = "(m,m),(m,n)->(m,n)"
def perform(self, node, inputs, outputs):
pass
......@@ -157,8 +167,8 @@ class SolveBase(Op):
if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in (1, 2):
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
if b.ndim != self.b_ndim:
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve(
......@@ -209,6 +219,16 @@ class SolveBase(Op):
return [A_bar, b_bar]
def _default_b_ndim(b, b_ndim):
if b_ndim is not None:
assert b_ndim in (1, 2)
return b_ndim
b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
class CholeskySolve(SolveBase):
def __init__(self, **kwargs):
kwargs.setdefault("lower", True)
......@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase):
raise NotImplementedError()
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
    """Solve the linear equations A x = b, given the Cholesky factorization of A.

    Parameters
    ----------
    c_and_lower : tuple
        ``(A, lower)`` — the Cholesky factor of the system matrix and a bool
        flag indicating whether it is the lower-triangular factor.
    b
        Right-hand side.
    check_finite : bool
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    b_ndim : int
        Whether the core case of b is a vector (1) or matrix (2).
        This will influence how batched dimensions are interpreted.
    """
    A, lower = c_and_lower
    b_ndim = _default_b_ndim(b, b_ndim)
    return Blockwise(
        CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
    )(A, b)
class SolveTriangular(SolveBase):
......@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase):
"unit_diagonal",
"lower",
"check_finite",
"b_ndim",
)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
def solve_triangular(
    a,
    b,
    *,
    trans=0,
    lower: bool = False,
    unit_diagonal: bool = False,
    check_finite: bool = True,
    b_ndim: Optional[int] = None,
) -> TensorVariable:
    """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.

    Parameters
    ----------
    a
        Triangular system matrix (or batch thereof).
    b
        Right-hand side.
    trans
        Transposition mode forwarded to the SolveTriangular Op.
    lower : bool
        Whether ``a`` is lower-triangular.
    unit_diagonal : bool
        Whether the diagonal of ``a`` is assumed to be all ones.
    check_finite : bool
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    b_ndim : int
        Whether the core case of b is a vector (1) or matrix (2).
        This will influence how batched dimensions are interpreted.
    """
    b_ndim = _default_b_ndim(b, b_ndim)
    return Blockwise(
        SolveTriangular(
            lower=lower,
            trans=trans,
            unit_diagonal=unit_diagonal,
            check_finite=check_finite,
            b_ndim=b_ndim,
        )
    )(a, b)
......@@ -332,6 +367,7 @@ class Solve(SolveBase):
"assume_a",
"lower",
"check_finite",
"b_ndim",
)
def __init__(self, *, assume_a="gen", **kwargs):
......@@ -352,7 +388,15 @@ class Solve(SolveBase):
)
def solve(
    a,
    b,
    *,
    assume_a="gen",
    lower=False,
    check_finite=True,
    b_ndim: Optional[int] = None,
):
    """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.

    If the data matrix is known to be a particular type then supplying the
    corresponding string via ``assume_a`` selects the dedicated solver.

    Parameters
    ----------
    a : (..., N, N) array_like
        Square input data
    b : (..., N, NRHS) array_like
        Input data for the right hand side.
    lower : bool, optional
        If True, only the data contained in the lower triangle of `a`. Default
        is to use the upper triangle.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    assume_a : str, optional
        Valid entries are explained above.
    b_ndim : int
        Whether the core case of b is a vector (1) or matrix (2).
        This will influence how batched dimensions are interpreted.
    """
    b_ndim = _default_b_ndim(b, b_ndim)
    return Blockwise(
        Solve(
            lower=lower,
            check_finite=check_finite,
            assume_a=assume_a,
            b_ndim=b_ndim,
        )
    )(a, b)
......
......@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower)(A, x)
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
],
)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower=lower)(A, x)
g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
......@@ -9,11 +9,12 @@ from pytensor import function
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
......@@ -23,7 +24,7 @@ def test_rop_lop():
mx = matrix("mx")
mv = matrix("mv")
v = vector("v")
y = matrix_inverse(mx).sum(axis=0)
y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv)
rop_f = function([mx, mv], yv)
......@@ -83,13 +84,11 @@ def test_transinv_to_invtrans():
def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A")
x = matrix("x")
L = cholesky_lower(A)
U = cholesky_upper(A)
L = cholesky(A, lower=True)
U = cholesky(A, lower=False)
b1 = solve(L, x)
b2 = solve(U, x)
f = pytensor.function([A, x], b1)
......@@ -130,15 +129,15 @@ def test_matrix_inverse_solve():
b = dmatrix("b")
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve)
assert isinstance(out.owner.op, Blockwise) and isinstance(
out.owner.op.core_op, Solve
)
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
cholesky = Cholesky(lower=(cholesky_form == "lower"))
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag
......@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else:
M = A
C = cholesky(M)
C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
print(f.maker.fgraph.apply_nodes)
no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
)
......
......@@ -24,6 +24,7 @@ def test_vectorize_blockwise():
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse
)
assert vect_node.op.signature == ("(m,m)->(m,m)")
assert vect_node.inputs[0] is tns
# Useless blockwise
......@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester):
signature = "(m, m) -> (m, m)"
class TestSolveVector(BlockwiseOpTester):
    # Solve with a vector right-hand side: core case (m,m),(m)->(m).
    core_op = Solve(lower=True, b_ndim=1)
    signature = "(m, m),(m) -> (m)"


class TestSolveMatrix(BlockwiseOpTester):
    # Solve with a matrix right-hand side: core case (m,m),(m,n)->(m,n).
    core_op = Solve(lower=True, b_ndim=2)
    signature = "(m, m),(m, n) -> (m, n)"
......@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester):
(
matrix,
functools.partial(tensor, dtype="floatX", shape=(None,) * 3),
"`b` must be a matrix or a vector.*",
"`b` must have 2 dims.*",
),
],
)
......@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester):
with pytest.raises(ValueError, match=error_message):
A = A_func()
b = b_func()
SolveBase()(A, b)
SolveBase(b_ndim=2)(A, b)
def test__repr__(self):
    """The repr of a SolveBase application includes b_ndim in its props."""
    np.random.default_rng(utt.fetch_seed())
    A = matrix()
    b = matrix()
    y = SolveBase(b_ndim=2)(A, b)
    assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
class TestSolve(utt.InferShapeTester):
def test__init__(self):
    """An unrecognized assume_a value raises a ValueError at construction."""
    with pytest.raises(ValueError) as excinfo:
        Solve(assume_a="test", b_ndim=2)
    assert "is not a recognized matrix structure" in str(excinfo.value)
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
......@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester):
if config.floatX == "float64":
eps = 2e-8
solve_op = Solve(assume_a=assume_a, lower=lower)
solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
......@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester):
if config.floatX == "float64":
eps = 2e-8
solve_op = SolveTriangular(lower=lower)
solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester):
def setup_method(self):
    # Store the Op class (not an instance): tests construct instances with
    # the b_ndim they need, since CholeskySolve now requires b_ndim.
    self.op_class = CholeskySolve
    super().setup_method()
def test_repr(self):
    """The CholeskySolve repr includes all props, including b_ndim."""
    assert (
        repr(CholeskySolve(lower=True, b_ndim=1))
        == "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
    )
def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())
......@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = matrix()
self._compile_and_check(
[A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs
[self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs
# A must be square
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
......@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = vector()
self._compile_and_check(
[A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs
[self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs
# A must be square
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
......@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = self.op(A, b)
y = self.op_class(lower=True, b_ndim=2)(A, b)
cho_solve_lower_func = pytensor.function([A, b], y)
y = self.op_upper(A, b)
y = self.op_class(lower=False, b_ndim=2)(A, b)
cho_solve_upper_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
......@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester):
A_val = np.eye(2)
b_val = np.ones((2, 1))
op = self.op_class(b_ndim=2)
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype)
x = self.op(A, b)
x = op(A, b)
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment