提交 39aa1234 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Extend cholesky of triangular dot rewrite to matmul Ops

Also restrict to 2D Dot cases
上级 00546b9f
...@@ -6,7 +6,7 @@ from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes ...@@ -6,7 +6,7 @@ from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
...@@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node): ...@@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node):
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
Also works with matmul.
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
""" """
if not isinstance(node.op.core_op, Cholesky): if not isinstance(node.op.core_op, Cholesky):
return return
A = node.inputs[0] A = node.inputs[0]
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))): if not (
A.owner is not None
and (
(
isinstance(A.owner.op, (Dot, Dot22))
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
)
):
return return
l, r = A.owner.inputs l, r = A.owner.inputs
......
from functools import partial
import numpy as np import numpy as np
import numpy.linalg import numpy.linalg
import pytest import pytest
...@@ -9,13 +11,14 @@ from pytensor import function ...@@ -9,13 +11,14 @@ from pytensor import function
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile import get_default_mode from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -137,18 +140,20 @@ def test_matrix_inverse_solve(): ...@@ -137,18 +140,20 @@ def test_matrix_inverse_solve():
@pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None)) @pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product): @pytest.mark.parametrize("op", (dot, matmul))
def test_cholesky_ldotlt(tag, cholesky_form, product, op):
transform_removes_chol = tag is not None and product == tag transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag transform_transposes = transform_removes_chol and cholesky_form != tag
A = matrix("L") ndim = 2 if op == dot else 3
A = tensor("L", shape=(None,) * ndim)
if tag: if tag:
setattr(A.tag, tag + "_triangular", True) setattr(A.tag, tag + "_triangular", True)
if product == "lower": if product == "lower":
M = A.dot(A.T) M = op(A, swapaxes(A, -1, -2))
elif product == "upper": elif product == "upper":
M = A.T.dot(A) M = op(swapaxes(A, -1, -2), A)
else: else:
M = A M = A
...@@ -156,14 +161,17 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -156,14 +161,17 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
no_cholesky_in_graph = not any( no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes isinstance(node.op, Cholesky)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky))
for node in f.maker.fgraph.apply_nodes
) )
assert no_cholesky_in_graph == transform_removes_chol assert no_cholesky_in_graph == transform_removes_chol
if transform_transposes: if transform_transposes:
expected_order = (1, 0) if ndim == 2 else (0, 2, 1)
assert any( assert any(
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0) isinstance(node.op, DimShuffle) and node.op.new_order == expected_order
for node in f.maker.fgraph.apply_nodes for node in f.maker.fgraph.apply_nodes
) )
...@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
] ]
) )
cholesky_vect_fn = np.vectorize(
partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")),
signature="(a, a)->(a, a)",
)
for Av in Avs: for Av in Avs:
if tag == "upper": if tag == "upper":
Av = Av.T Av = Av.T
...@@ -194,12 +207,14 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -194,12 +207,14 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else: else:
Mv = Av Mv = Av
assert np.all( if ndim == 3:
np.isclose( Av = np.broadcast_to(Av, (5, *Av.shape))
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")), Mv = np.broadcast_to(Mv, (5, *Mv.shape))
np.testing.assert_allclose(
cholesky_vect_fn(Mv),
f(Av), f(Av),
) )
)
def test_local_det_chol(): def test_local_det_chol():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论