提交 39aa1234 作者: Ricardo Vieira 提交者: Ricardo Vieira

Extend cholesky of triangular dot rewrite to matmul Ops

Also restrict the Dot/Dot22 case to 2D (matrix) inputs
上级 00546b9f
......@@ -6,7 +6,7 @@ from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
......@@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node):
Rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
or cholesky(dot(U.T, U), lower=False) = U, where U is upper triangular.
Also works with matmul.
This relies on a boolean `lower_triangular` or `upper_triangular` tag on the matrices.
"""
if not isinstance(node.op.core_op, Cholesky):
return
A = node.inputs[0]
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
if not (
A.owner is not None
and (
(
isinstance(A.owner.op, (Dot, Dot22))
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
)
):
return
l, r = A.owner.inputs
......
from functools import partial
import numpy as np
import numpy.linalg
import pytest
......@@ -9,13 +11,14 @@ from pytensor import function
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
......@@ -137,18 +140,20 @@ def test_matrix_inverse_solve():
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
@pytest.mark.parametrize("op", (dot, matmul))
def test_cholesky_ldotlt(tag, cholesky_form, product, op):
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag
A = matrix("L")
ndim = 2 if op == dot else 3
A = tensor("L", shape=(None,) * ndim)
if tag:
setattr(A.tag, tag + "_triangular", True)
if product == "lower":
M = A.dot(A.T)
M = op(A, swapaxes(A, -1, -2))
elif product == "upper":
M = A.T.dot(A)
M = op(swapaxes(A, -1, -2), A)
else:
M = A
......@@ -156,14 +161,17 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
isinstance(node.op, Cholesky)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky))
for node in f.maker.fgraph.apply_nodes
)
assert no_cholesky_in_graph == transform_removes_chol
if transform_transposes:
expected_order = (1, 0) if ndim == 2 else (0, 2, 1)
assert any(
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
isinstance(node.op, DimShuffle) and node.op.new_order == expected_order
for node in f.maker.fgraph.apply_nodes
)
......@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
]
)
cholesky_vect_fn = np.vectorize(
partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")),
signature="(a, a)->(a, a)",
)
for Av in Avs:
if tag == "upper":
Av = Av.T
......@@ -194,12 +207,14 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else:
Mv = Av
assert np.all(
np.isclose(
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
if ndim == 3:
Av = np.broadcast_to(Av, (5, *Av.shape))
Mv = np.broadcast_to(Mv, (5, *Mv.shape))
np.testing.assert_allclose(
cholesky_vect_fn(Mv),
f(Av),
)
)
def test_local_det_chol():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论