提交 99a040cc authored 作者: David Horsley's avatar David Horsley 提交者: Ricardo Vieira

Add cholesky of L.LT rewrite

上级 6d431aa6
......@@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node):
return [x]
@register_canonicalize
@register_stabilize
@node_rewriter([Cholesky])
def cholesky_ldotlt(fgraph, node):
"""
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
"""
if not isinstance(node.op, Cholesky):
return
A = node.inputs[0]
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
return
l, r = A.owner.inputs
# cholesky(dot(L,L.T)) case
if (
getattr(l.tag, "lower_triangular", False)
and r.owner
and isinstance(r.owner.op, DimShuffle)
and r.owner.op.new_order == (1, 0)
and r.owner.inputs[0] == l
):
if node.op.lower:
return [l]
return [r]
# cholesky(dot(U.T,U)) case
if (
getattr(r.tag, "upper_triangular", False)
and l.owner
and isinstance(l.owner.op, DimShuffle)
and l.owner.op.new_order == (1, 0)
and l.owner.inputs[0] == r
):
if node.op.lower:
return [l]
return [r]
@register_stabilize
@register_specialize
@node_rewriter([Det])
......
import numpy as np
import numpy.linalg
import pytest
import scipy.linalg
import pytensor
from pytensor import function
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
......@@ -105,3 +108,75 @@ def test_matrix_inverse_solve():
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve)
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
cholesky = Cholesky(lower=(cholesky_form == "lower"))
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag
A = matrix("L")
if tag:
setattr(A.tag, tag + "_triangular", True)
if product == "lower":
M = A.dot(A.T)
elif product == "upper":
M = A.T.dot(A)
else:
M = A
C = cholesky(M)
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
print(f.maker.fgraph.apply_nodes)
no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
)
assert no_cholesky_in_graph == transform_removes_chol
if transform_transposes:
assert any(
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
for node in f.maker.fgraph.apply_nodes
)
# Test some concrete value through f
# there must be lower triangular (f assumes they are)
Avs = [
np.eye(1, dtype=pytensor.config.floatX),
np.eye(10, dtype=pytensor.config.floatX),
np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX),
]
if not tag:
# these must be positive def
Avs.extend(
[
np.ones((4, 4), dtype=pytensor.config.floatX)
+ np.eye(4, dtype=pytensor.config.floatX),
]
)
for Av in Avs:
if tag == "upper":
Av = Av.T
if product == "lower":
Mv = Av.dot(Av.T)
elif product == "upper":
Mv = Av.T.dot(Av)
else:
Mv = Av
assert np.all(
np.isclose(
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
f(Av),
)
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论