Unverified commit c4ae6e34, authored by Jesse Grabowski, committed by GitHub

Add `linalg.block_diag` and sparse equivalent (#576)

* Copy `block_diag` and support functions from `pymc.math`
* Evaluate output in sphinx code example
  Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
* Test type equivalence with `isinstance` instead of `==`
  Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
* Fix typo in test function
* Split `block_diag` into sparse and dense versions; closely follow the scipy function signature for `block_diag`
* Use `as_sparse_or_tensor_variable` in `SparseBlockDiagonalMatrix` to allow sparse matrix inputs to `pytensor.sparse.block_diag`
* Test sparse and dense inputs to `pytensor.sparse.block_diag`
* Add numba overload for `pytensor.tensor.slinalg.block_diag`
* Add jax overload for `pytensor.tensor.slinalg.block_diag`
* Move stand-alone `block_diag_grad` function into the `grad` method
* Add `format` prop to `SparseBlockDiagonalMatrix`
* Use `compare_numba_and_py` in `numba/test_slinalg.py::test_block_diag`
* Add support for Blockwise to `slinalg.block_diag`
* Add gradient test; remove `Matrix` from the `BlockDiagonal` and `SparseBlockDiagonal` `Op` names; correct errors in docstrings; move input validation to a shared class method
* Remove `gufunc_signature` from `__props__`
  Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
* Implement correct `__props__` for subclasses of `BaseBlockMatrix`

---------

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Parent 96f753b0
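In brief: the commit adds a dense `block_diag` to `pytensor.tensor.slinalg` (with Blockwise support plus numba and JAX overloads) and a sparse counterpart to `pytensor.sparse`. A minimal usage sketch of the two new entry points (shapes and values here are illustrative only):

import numpy as np
from pytensor.sparse import block_diag as sparse_block_diag
from pytensor.tensor.slinalg import block_diag

A = np.eye(2)
B = np.ones((3, 3))

# Dense: a symbolic 5x5 matrix with A and B on the diagonal.
print(block_diag(A, B).eval().shape)  # (5, 5)

# Sparse: same layout, but evaluates to a scipy.sparse matrix in the requested format.
print(type(sparse_block_diag(A, B, format="csr").eval()))  # csr_matrix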
import jax
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
@jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def jax_funcify_SolveTriangular(op, **kwargs):
)
return solve_triangular
@jax_funcify.register(BlockDiagonal)
def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
def block_diag(*inputs):
return jax.scipy.linalg.block_diag(*inputs)
return block_diag
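The overload simply forwards to `jax.scipy.linalg.block_diag`. One way to exercise this dispatch path, assuming jax is installed (a sketch, not part of the diff):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.matrix("A")
B = pt.matrix("B")
f = pytensor.function([A, B], block_diag(A, B), mode="JAX")

x = np.eye(2, dtype=pytensor.config.floatX)
y = np.ones((3, 3), dtype=pytensor.config.floatX)
print(f(x, y).shape)  # (5, 5)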
@@ -9,7 +9,7 @@ from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
_PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
return res
return solve_triangular
@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
# TODO: Why do we always inline all functions? Inlining does not work with starred args, so it cannot be used in this case.
@numba_basic.numba_njit(inline="never")
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
r, c = 0, 0
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
c += cc
return out
return block_diag
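The kernel allocates an output of the summed block shapes and copies each block along the diagonal, advancing a row and column cursor. The same bookkeeping in plain numpy, checked against scipy (a sketch, not part of the diff):

import numpy as np
from scipy.linalg import block_diag as sp_block_diag

arrs = [np.ones((2, 3)), np.eye(4)]
shapes = np.array([a.shape for a in arrs])
out = np.zeros(shapes.sum(axis=0), dtype=arrs[0].dtype)  # (6, 7) output
r = c = 0
for arr, (rr, cc) in zip(arrs, shapes):
    out[r : r + rr, c : c + cc] = arr
    r += rr
    c += cc
np.testing.assert_allclose(out, sp_block_diag(*arrs))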
@@ -7,6 +7,7 @@ http://www-users.cs.umn.edu/~saad/software/SPARSKIT/paper.ps
TODO: Automatic methods for determining best sparse format?
"""
from typing import Literal
from warnings import warn
import numpy as np
@@ -47,6 +48,7 @@ from pytensor.tensor.math import (
trunc,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@ from pytensor.tensor.variable import (
sparse_formats = ["csc", "csr"]
"""
Types of sparse matrices to use for testing.
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
as_sparse = as_sparse_variable
as_sparse_or_tensor_variable = as_symbolic
@@ -1800,7 +1800,7 @@ class SpSum(Op):
return r
def __str__(self):
return f"{self.__class__.__name__ }{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{axis={self.axis}}}"
def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ class GreaterEqualSD(__ComparisonOpSD):
greater_equal_s_d = GreaterEqualSD()
eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ class Remove0(Op):
l = []
if self.inplace:
l.append("inplace")
return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
return f"{self.__class__.__name__}{{{', '.join(l)}}}"
def make_node(self, x):
"""
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
# Simplify code by splitting into DotSS and DotSD.
__props__ = ()
# The grad_preserves_dense attribute doesn't change the
# execution behavior. To let the optimizer merge nodes with
# different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ class ConstructSparseFromList(Op):
construct_sparse_from_list = ConstructSparseFromList()
class SparseBlockDiagonal(BaseBlockDiagonal):
__props__ = (
"n_inputs",
"format",
)
def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
super().__init__(n_inputs)
self.format = format
def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(
matrices, as_sparse_or_tensor_variable
)
dtype = _largest_common_dtype(matrices)
out_type = matrix(format=self.format, dtype=dtype)
return Apply(self, matrices, [out_type])
def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.sparse.block_diag(
inputs, format=self.format
).astype(dtype)
def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
r"""
Construct a block diagonal matrix from a sequence of input matrices.
Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
[[A, 0, 0],
[0, B, 0],
[0, 0, C]]
Parameters
----------
A, B, C ... : tensors
    Input matrices to form the block diagonal matrix. Each input must be exactly 2-dimensional.
    Note that the input matrices need not be sparse themselves; dense inputs are automatically
    converted to the requested sparse format.
format: str, optional
    The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
Returns
-------
out: sparse matrix tensor
Symbolic sparse matrix in the specified format.
Examples
--------
Create a sparse block diagonal matrix from two sparse 2x2 matrices:
.. code-block:: python
import numpy as np
from pytensor.sparse import block_diag
from scipy.sparse import csr_matrix
A = csr_matrix([[1, 2], [3, 4]])
B = csr_matrix([[5, 6], [7, 8]])
result_sparse = block_diag(A, B, format='csr')
print(result_sparse)
>>> SparseVariable{csr,int32}
print(result_sparse.toarray().eval())
>>> array([[1, 2, 0, 0],
>>> [3, 4, 0, 0],
>>> [0, 0, 5, 6],
>>> [0, 0, 7, 8]])
"""
if len(matrices) == 1:
    return matrices[0]
_sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
return _sparse_block_diagonal(*matrices)
@@ -4279,6 +4279,25 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr.shape, indices, axis)]
def ix_(*args):
"""
PyTensor analog of np.ix_.

See numpy.lib.index_tricks.ix_ for reference.
"""
out = []
nd = len(args)
for k, new in enumerate(args):
if new is None:
out.append(slice(None))
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
out.append(new)
return tuple(out)
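`ix_` builds the same open mesh as `np.ix_`: each 1-d index vector is reshaped so the pieces broadcast to a rectangular block under advanced indexing. A small sketch (assuming this hunk lands in `pytensor.tensor.basic`, which the `ptb.ix_` call in the new `grad` method suggests):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.basic import ix_

X = pt.as_tensor(np.arange(16).reshape(4, 4))
rows = pt.as_tensor(np.array([2, 3]))
cols = pt.as_tensor(np.array([0, 1]))

# Equivalent to X[2:4, 0:2]: indices reshaped to (2, 1) and (1, 2) broadcast to a 2x2 block.
print(X[ix_(rows, cols)].eval())
# [[ 8  9]
#  [12 13]]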
__all__ = [
"take_along_axis",
"expand_dims",
import logging
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
@@ -23,7 +24,6 @@ from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
logger = logging.getLogger(__name__)
@@ -908,6 +908,107 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
)
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
    return reduce(np.promote_types, [x.dtype for x in tensors])
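`_largest_common_dtype` folds numpy's type-promotion rules over all input dtypes, so mixed inputs get a safe common output dtype. For example (a sketch):

from functools import reduce
import numpy as np

# int64 and float32 promote to float64 under numpy's casting rules.
print(reduce(np.promote_types, ["int64", "float32"]))  # float64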
class BaseBlockDiagonal(Op):
__props__ = ("n_inputs",)
def __init__(self, n_inputs):
    if n_inputs == 0:
        raise ValueError("n_inputs must be greater than 0")
    self.n_inputs = n_inputs
    # e.g. "(m0,n0),(m1,n1)->(m,n)" for two inputs
    input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
    self.gufunc_signature = f"{input_sig}->(m,n)"
def grad(self, inputs, gout):
    # Each input's gradient is the matching diagonal block of the output
    # gradient; the numpy sketch after this class spells out the indexing.
    shapes = pt.stack([i.shape for i in inputs])
    index_end = shapes.cumsum(0)
    index_begin = index_end - shapes
slices = [
ptb.ix_(
pt.arange(index_begin[i, 0], index_end[i, 0]),
pt.arange(index_begin[i, 1], index_end[i, 1]),
)
for i in range(len(inputs))
]
return [gout[0][slc] for slc in slices]
def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes)
return [(pt.add(*first), pt.add(*second))]
def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
if len(matrices) != self.n_inputs:
raise ValueError(
f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
)
matrices = list(map(as_tensor_func, matrices))
if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("All inputs must have dimension 2")
return matrices
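In numpy terms, the `grad` method's bookkeeping looks like this: cumulative sums of the block shapes give each block's offsets, and `ix_` slices the matching block out of the output gradient (a sketch, not part of the diff):

import numpy as np

g_out = np.arange(25.0).reshape(5, 5)  # upstream gradient for a 5x5 output
shapes = np.array([(2, 2), (3, 3)])    # shapes of the two input blocks
index_end = shapes.cumsum(axis=0)
index_begin = index_end - shapes
grads = [
    g_out[np.ix_(np.arange(b[0], e[0]), np.arange(b[1], e[1]))]
    for b, e in zip(index_begin, index_end)
]
print([g.shape for g in grads])  # [(2, 2), (3, 3)]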
class BlockDiagonal(BaseBlockDiagonal):
__props__ = ("n_inputs",)
def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
dtype = _largest_common_dtype(matrices)
out_type = pytensor.tensor.matrix(dtype=dtype)
return Apply(self, matrices, [out_type])
def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
def block_diag(*matrices: TensorVariable):
"""
Construct a block diagonal matrix from a sequence of input tensors.
Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
[[A, 0, 0],
[0, B, 0],
[0, 0, C]]
Parameters
----------
A, B, C ... : tensors
    Input tensors to form the block diagonal matrix. The last two dimensions of each input are used,
    and all inputs must have at least 2 dimensions.
Returns
-------
out: tensor
The block diagonal matrix formed from the input matrices.
Examples
--------
Create a block diagonal matrix from two 2x2 matrices:
.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.linalg import block_diag

    A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
    B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
    result = block_diag(A, B)
    print(result.eval())
    >>> array([[1, 2, 0, 0],
    >>> [3, 4, 0, 0],
    >>> [0, 0, 5, 6],
    >>> [0, 0, 7, 8]])
"""
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
return _block_diagonal_matrix(*matrices)
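Because the core Op is wrapped in `Blockwise`, leading batch dimensions broadcast like a numpy gufunc, which the Blockwise tests below exercise. A short sketch of the batched behaviour:

import numpy as np
import pytensor
from pytensor.tensor.slinalg import block_diag

A = np.random.normal(size=(10, 2, 2)).astype(pytensor.config.floatX)
B = np.random.normal(size=(10, 3, 3)).astype(pytensor.config.floatX)

# One 5x5 block-diagonal matrix per batch entry.
print(block_diag(A, B).eval().shape)  # (10, 5, 5)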
__all__ = [
"cholesky",
"solve",
@@ -918,4 +1019,5 @@ __all__ = [
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
]
@@ -129,3 +129,37 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
np.arange(10).astype(config.floatX),
],
)
def test_jax_block_diag():
A = matrix("A")
B = matrix("B")
C = matrix("C")
D = matrix("D")
out = pt_slinalg.block_diag(A, B, C, D)
out_fg = FunctionGraph([A, B, C, D], [out])
compare_jax_and_py(
out_fg,
[
np.random.normal(size=(5, 5)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX),
np.random.normal(size=(2, 2)).astype(config.floatX),
np.random.normal(size=(4, 4)).astype(config.floatX),
],
)
def test_jax_block_diag_blockwise():
A = pt.tensor3("A")
B = pt.tensor3("B")
out = pt_slinalg.block_diag(A, B)
out_fg = FunctionGraph([A, B], [out])
compare_jax_and_py(
out_fg,
[
np.random.normal(size=(5, 5, 5)).astype(config.floatX),
np.random.normal(size=(5, 3, 3)).astype(config.floatX),
],
)
@@ -6,11 +6,11 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config
from tests.link.numba.test_basic import compare_numba_and_py
numba = pytest.importorskip("numba")
ATOL = 0 if config.floatX.endswith("64") else 1e-6
RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
rng = np.random.default_rng(42849)
@@ -102,3 +102,18 @@ def test_solve_triangular_raises_on_nan_inf(value):
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
):
f(A_tri, b)
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")
D = pt.matrix("D")
X = pt.linalg.block_diag(A, B, C, D)
A_val = np.random.normal(size=(5, 5))
B_val = np.random.normal(size=(3, 3))
C_val = np.random.normal(size=(2, 2))
D_val = np.random.normal(size=(4, 4))
out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
@@ -51,6 +51,7 @@ from pytensor.sparse import (
add_s_s_data,
as_sparse_or_tensor_variable,
as_sparse_variable,
block_diag,
cast,
clean,
construct_sparse_from_list,
@@ -3389,3 +3390,21 @@ class TestSamplingDot(utt.InferShapeTester):
)
class TestSharedOptions:
pass
@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
@pytest.mark.parametrize("sparse_input", [True, False], ids=["sparse", "dense"])
def test_block_diagonal(format, sparse_input):
from scipy import sparse as sp_sparse
f_array = sp_sparse.csr_matrix if sparse_input else np.array
A = f_array([[1, 2], [3, 4]]).astype(config.floatX)
B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
result = block_diag(A, B, format=format)
assert result.owner.op._props_dict() == {"n_inputs": 2, "format": format}
sp_result = sp_sparse.block_diag([A, B], format=format)
assert isinstance(result.eval(), type(sp_result))
np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())
@@ -15,6 +15,7 @@ from pytensor.tensor.slinalg import (
Solve,
SolveBase,
SolveTriangular,
block_diag,
cho_solve,
cholesky,
eigvalsh,
@@ -661,3 +662,40 @@ def test_solve_discrete_are_grad():
rng=rng,
abs_tol=atol,
)
def test_block_diagonal():
A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])
result = block_diag(A, B)
assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
def test_block_diagonal_grad():
A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])
utt.verify_grad(block_diag, pt=[A, B], rng=np.random.default_rng())
def test_block_diagonal_blockwise():
batch_size = 5
A = np.random.normal(size=(batch_size, 2, 2)).astype(config.floatX)
B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
result = block_diag(A, B).eval()
assert result.shape == (batch_size, 6, 6)
for i in range(batch_size):
np.testing.assert_allclose(
result[i],
scipy.linalg.block_diag(A[i], B[i]),
atol=1e-4 if config.floatX == "float32" else 1e-8,
rtol=1e-4 if config.floatX == "float32" else 1e-8,
)
# Test broadcasting
A = np.random.normal(size=(10, batch_size, 2, 2)).astype(config.floatX)
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
result = block_diag(A, B).eval()
assert result.shape == (10, batch_size, 6, 6)