Commit 9df35e8d authored by jessegrabowski, committed by Ricardo Vieira

Add rewrite to lift linear algebra through certain linalg ops

Parent: d34760d7
......@@ -7,7 +7,7 @@ from functools import partial
from typing import cast
import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable
......
......@@ -7,6 +7,7 @@ import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
......@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
return TensorSolve(axes)(a, b)
class KroneckerProduct(OpFromGraph):
    """
    Wrapper Op for Kronecker graphs.

    Wrapping the Kronecker-product graph in a named `OpFromGraph` lets
    graph rewrites recognize a Kronecker product with a plain
    ``isinstance`` check instead of pattern-matching the underlying
    reshape/multiply graph.
    """
def kron(a, b):
"""Kronecker product.
......@@ -1042,7 +1049,8 @@ def kron(a, b):
out_shape = tuple(a.shape * b.shape)
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped
return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)
__all__ = [
......
import logging
from collections.abc import Callable
from typing import cast
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.nlinalg import (
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
Solve,
SolveBase,
block_diag,
cholesky,
solve,
solve_triangular,
......@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):
# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.
@register_specialize
@node_rewriter([Blockwise])
def local_lift_through_linalg(
    fgraph: FunctionGraph, node: Apply
) -> list[Variable] | None:
    """
    Lift expensive linear-algebra Ops (Cholesky, Inverse, Pseudo-inverse)
    through Ops that assemble matrices from components (KroneckerProduct,
    BlockDiagonal).

    These operations commute with matrix assembly, so applying them to each
    component first and assembling afterwards gives the same result while
    operating on smaller matrices. E.g. the inverse of a Kronecker product
    is the Kronecker product of the inverses, turning an O((n*m)^3) inverse
    into O(n^3 + m^3) for components of sizes n and m.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # TODO: Simplify this if we end up Blockwising KroneckerProduct
    outer_op = node.op
    if not isinstance(outer_op.core_op, MatrixInverse | Cholesky | MatrixPinv):
        return None

    y = node.inputs[0]
    if not y.owner:
        return None

    inner_op = y.owner.op
    # BlockDiagonal arrives Blockwise-wrapped; KroneckerProduct does not.
    lifting_block_diag = isinstance(inner_op, Blockwise) and isinstance(
        inner_op.core_op, BlockDiagonal
    )
    lifting_kron = isinstance(inner_op, KroneckerProduct)
    if not (lifting_block_diag or lifting_kron):
        return None

    # Pick the user-facing constructor matching the expensive core Op.
    if isinstance(outer_op.core_op, MatrixInverse):
        apply_to_component = cast(Callable, inv)
    elif isinstance(outer_op.core_op, Cholesky):
        apply_to_component = cast(Callable, cholesky)
    elif isinstance(outer_op.core_op, MatrixPinv):
        apply_to_component = cast(Callable, pinv)
    else:
        raise NotImplementedError  # pragma: no cover

    # Apply the expensive Op to each component, then reassemble.
    lifted_components = [
        cast(TensorVariable, apply_to_component(m)) for m in y.owner.inputs
    ]
    if lifting_kron:
        return [kron(*lifted_components)]
    elif lifting_block_diag:
        return [block_diag(*lifted_components)]
    else:
        raise NotImplementedError  # pragma: no cover
......@@ -14,9 +14,16 @@ from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.nlinalg import (
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
Solve,
SolveBase,
......@@ -333,3 +340,53 @@ class TestBatchedVectorBSolveToMatrixBSolve:
ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
)
@pytest.mark.parametrize(
    "constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
)
@pytest.mark.parametrize(
    "f_op, f",
    [
        (MatrixInverse, pt.linalg.inv),
        (Cholesky, pt.linalg.cholesky),
        (MatrixPinv, pt.linalg.pinv),
    ],
    ids=["inv", "cholesky", "pinv"],
)
@pytest.mark.parametrize(
    "g_op, g",
    [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
    ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    """The expensive op `f` should be pushed below the assembling op `g`,
    and the rewritten graph must agree numerically with the original."""
    if pytensor.config.floatX.endswith("32"):
        pytest.skip("Test is flaky at half precision")

    A, B = (constructor(name) for name in "ab")
    X = f(g(A, B))

    fn_rewritten = pytensor.function(
        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
    )
    fn_reference = pytensor.function(
        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
    )

    def count_ops(fn, op_type):
        # Count apply nodes whose (possibly Blockwise-wrapped) op matches op_type
        return sum(
            isinstance(getattr(apply.op, "core_op", apply.op), op_type)
            for apply in fn.maker.fgraph.apply_nodes
        )

    # After lifting: one `f` per component matrix (2), a single `g` on top.
    assert count_ops(fn_rewritten, f_op) == 2
    assert count_ops(fn_rewritten, g_op) == 1

    raw_draws = [
        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
    ]
    # Symmetrize (x @ x.T) so cholesky/inv receive valid inputs
    test_vals = [x @ np.swapaxes(x, -1, -2) for x in raw_draws]

    np.testing.assert_allclose(
        fn_rewritten(*test_vals), fn_reference(*test_vals), atol=1e-8
    )
......@@ -590,6 +590,14 @@ class TestKron(utt.InferShapeTester):
self.op = kron
super().setup_method()
def test_vec_vec_kron_raises(self):
    """Two 1-D inputs are rejected: `kron` requires dims summing to >= 3."""
    u, v = vector(), vector()
    expected_msg = "kron: inputs dimensions must sum to 3 or more"
    with pytest.raises(TypeError, match=expected_msg):
        kron(u, v)
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1):
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment