Commit 9df35e8d authored by jessegrabowski, committed by Ricardo Vieira

Add rewrite to lift linear algebra through certain linalg ops

Parent: d34760d7
...@@ -7,7 +7,7 @@ from functools import partial ...@@ -7,7 +7,7 @@ from functools import partial
from typing import cast from typing import cast
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import function from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -614,7 +615,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True): ...@@ -614,7 +615,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
Returns Returns
------- -------
U, V, D : matrices U, V, D : matrices
""" """
return Blockwise(SVD(full_matrices, compute_uv))(a) return Blockwise(SVD(full_matrices, compute_uv))(a)
...@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None): ...@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
return TensorSolve(axes)(a, b) return TensorSolve(axes)(a, b)
class KroneckerProduct(OpFromGraph):
    """
    Wrapper Op for Kronecker-product graphs.

    Wrapping the ``kron`` graph in an ``OpFromGraph`` gives rewrites a single
    node type to match on, e.g. so expensive ops (inverse, Cholesky, pinv)
    can be lifted through the Kronecker product — see
    ``local_lift_through_linalg``.
    """
def kron(a, b): def kron(a, b):
"""Kronecker product. """Kronecker product.
...@@ -1042,7 +1049,8 @@ def kron(a, b): ...@@ -1042,7 +1049,8 @@ def kron(a, b):
out_shape = tuple(a.shape * b.shape) out_shape = tuple(a.shape * b.shape)
output_out_of_shape = a_reshaped * b_reshaped output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape) output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped
return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)
__all__ = [ __all__ = [
......
import logging import logging
from collections.abc import Callable
from typing import cast from typing import cast
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det from pytensor.tensor.nlinalg import (
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
)
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky, Cholesky,
Solve, Solve,
SolveBase, SolveBase,
block_diag,
cholesky, cholesky,
solve, solve,
solve_triangular, solve_triangular,
...@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node): ...@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):
# TODO: have a reduction like prod and sum that simply # TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication. # returns the sign of the prod multiplication.
@register_specialize
@node_rewriter([Blockwise])
def local_lift_through_linalg(
    fgraph: FunctionGraph, node: Apply
) -> list[Variable] | None:
    """
    Lift expensive linear algebra ops (inverse, pseudo-inverse, Cholesky)
    through Ops that assemble one large matrix from smaller component
    matrices (``KroneckerProduct``, ``BlockDiagonal``).

    These ops commute: e.g. the inverse of a Kronecker product equals the
    Kronecker product of the inverses, which drops the cost of the inverse
    from O((n*m)^3) to O(n^3 + m^3) for components of size n and m.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # TODO: Simplify this if we end up Blockwising KroneckerProduct
    outer_op = node.op
    if not isinstance(outer_op.core_op, MatrixInverse | Cholesky | MatrixPinv):
        return None

    y = node.inputs[0]
    if y.owner is None:
        return None

    producer = y.owner.op
    joins_block_diag = isinstance(producer, Blockwise) and isinstance(
        producer.core_op, BlockDiagonal
    )
    joins_kron = isinstance(producer, KroneckerProduct)
    if not (joins_block_diag or joins_kron):
        return None

    # Dispatch the Blockwise core op to its user-facing constructor.
    if isinstance(outer_op.core_op, MatrixInverse):
        apply_outer = cast(Callable, inv)
    elif isinstance(outer_op.core_op, Cholesky):
        apply_outer = cast(Callable, cholesky)
    elif isinstance(outer_op.core_op, MatrixPinv):
        apply_outer = cast(Callable, pinv)
    else:
        raise NotImplementedError  # pragma: no cover

    # Apply the expensive op to each component matrix, then re-join the results.
    components = [cast(TensorVariable, apply_outer(m)) for m in y.owner.inputs]
    if joins_kron:
        return [kron(*components)]
    return [block_diag(*components)]
...@@ -14,9 +14,16 @@ from pytensor.tensor import swapaxes ...@@ -14,9 +14,16 @@ from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import (
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky, Cholesky,
Solve, Solve,
SolveBase, SolveBase,
...@@ -333,3 +340,53 @@ class TestBatchedVectorBSolveToMatrixBSolve: ...@@ -333,3 +340,53 @@ class TestBatchedVectorBSolveToMatrixBSolve:
ref_fn(test_a, test_b), ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5, rtol=1e-7 if config.floatX == "float64" else 1e-5,
) )
@pytest.mark.parametrize(
    "constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
)
@pytest.mark.parametrize(
    "f_op, f",
    [
        (MatrixInverse, pt.linalg.inv),
        (Cholesky, pt.linalg.cholesky),
        (MatrixPinv, pt.linalg.pinv),
    ],
    ids=["inv", "cholesky", "pinv"],
)
@pytest.mark.parametrize(
    "g_op, g",
    [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
    ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    """After the rewrite, the expensive op f is applied per-component (twice)
    and the joining op g once, instead of g once followed by one big f."""
    if pytensor.config.floatX.endswith("32"):
        pytest.skip("Test is flaky at half precision")

    A, B = (constructor(name) for name in "ab")
    X = f(g(A, B))

    fn_rewritten = pytensor.function(
        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
    )
    fn_reference = pytensor.function(
        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
    )

    def count_core_ops(fn, op_type):
        # Look through Blockwise wrappers when identifying ops.
        return sum(
            isinstance(getattr(apply.op, "core_op", apply.op), op_type)
            for apply in fn.maker.fgraph.apply_nodes
        )

    assert count_core_ops(fn_rewritten, f_op) == 2
    assert count_core_ops(fn_rewritten, g_op) == 1

    # Symmetrize (x @ x.T) so inverse/cholesky/pinv are all well-defined.
    test_inputs = [
        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
    ]
    test_inputs = [x @ np.swapaxes(x, -1, -2) for x in test_inputs]
    np.testing.assert_allclose(
        fn_rewritten(*test_inputs), fn_reference(*test_inputs), atol=1e-8
    )
...@@ -590,6 +590,14 @@ class TestKron(utt.InferShapeTester): ...@@ -590,6 +590,14 @@ class TestKron(utt.InferShapeTester):
self.op = kron self.op = kron
super().setup_method() super().setup_method()
def test_vec_vec_kron_raises(self):
    """kron of two 1-d inputs (dimensions sum to 2) must be rejected."""
    first, second = vector(), vector()
    err_pattern = "kron: inputs dimensions must sum to 3 or more"
    with pytest.raises(TypeError, match=err_pattern):
        kron(first, second)
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]) @pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]) @pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1): def test_perform(self, shp0, shp1):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment