提交 99a040cc authored 作者: David Horsley's avatar David Horsley 提交者: Ricardo Vieira

Add cholesky of L.LT rewrite

上级 6d431aa6
...@@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node): ...@@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node):
return [x] return [x]
@register_canonicalize
@register_stabilize
@node_rewriter([Cholesky])
def cholesky_ldotlt(fgraph, node):
    """Replace the Cholesky factorization of a tagged triangular product by its factor.

    Two graph patterns are recognized:

    * ``cholesky(dot(L, L.T))`` where ``L.tag.lower_triangular`` is truthy,
    * ``cholesky(dot(U.T, U))`` where ``U.tag.upper_triangular`` is truthy.

    In both cases the left operand of the dot is returned when the Cholesky op
    has ``lower=True`` and the right operand when it has ``lower=False``; i.e.
    ``L`` / ``L.T`` for the first pattern and ``U.T`` / ``U`` for the second.

    The triangularity information comes solely from the boolean
    ``lower_triangular`` / ``upper_triangular`` attribute on the variable's
    tag; the rewrite does not inspect the values of the matrix.

    NOTE(review): mathematically the identity also needs the diagonal of the
    tagged factor to be positive; nothing here checks that -- confirm that
    only such matrices get tagged.

    Returns
    -------
    list or None
        A single-element list with the replacement variable, or ``None`` when
        the node does not match either pattern.
    """
    if not isinstance(node.op, Cholesky):
        return None

    product = node.inputs[0]
    product_owner = product.owner
    if not product_owner or not isinstance(product_owner.op, (Dot, Dot22)):
        return None

    left, right = product_owner.inputs

    def is_transpose_of(candidate, base):
        # A matrix transpose shows up in the graph as DimShuffle with order (1, 0)
        # applied directly to `base`.
        owner = candidate.owner
        return (
            owner
            and isinstance(owner.op, DimShuffle)
            and owner.op.new_order == (1, 0)
            and owner.inputs[0] == base
        )

    # First alternative: dot(L, L.T) with L tagged lower triangular.
    # Second alternative: dot(U.T, U) with U tagged upper triangular.
    if not (
        (getattr(left.tag, "lower_triangular", False) and is_transpose_of(right, left))
        or (
            getattr(right.tag, "upper_triangular", False)
            and is_transpose_of(left, right)
        )
    ):
        return None

    # In either pattern the left operand is the lower-triangular factor and the
    # right operand is the upper-triangular one.
    return [left if node.op.lower else right]
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@node_rewriter([Det]) @node_rewriter([Det])
......
import numpy as np import numpy as np
import numpy.linalg import numpy.linalg
import pytest
import scipy.linalg
import pytensor import pytensor
from pytensor import function from pytensor import function
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
...@@ -105,3 +108,75 @@ def test_matrix_inverse_solve(): ...@@ -105,3 +108,75 @@ def test_matrix_inverse_solve():
node = matrix_inverse(A).dot(b).owner node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node) [out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve) assert isinstance(out.owner.op, Solve)
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
    """Check the ``cholesky_ldotlt`` rewrite structurally and numerically.

    Parameters
    ----------
    tag : {"lower", "upper", None}
        Which ``<tag>_triangular`` attribute (if any) is set on the input's tag.
    cholesky_form : {"lower", "upper"}
        Whether the Cholesky op under test is built with ``lower=True`` or
        ``lower=False``.
    product : {"lower", "upper", None}
        How the Cholesky input is built from ``A``: ``A.dot(A.T)``,
        ``A.T.dot(A)``, or ``A`` itself.
    """
    lower = cholesky_form == "lower"
    cholesky = Cholesky(lower=lower)

    # The Cholesky node is expected to disappear only when the product form
    # matches the triangular tag; a transpose is expected to remain when the
    # requested Cholesky form is the opposite of the tagged one.
    transform_removes_chol = tag is not None and product == tag
    transform_transposes = transform_removes_chol and cholesky_form != tag

    A = matrix("L")
    if tag:
        setattr(A.tag, tag + "_triangular", True)

    if product == "lower":
        M = A.dot(A.T)
    elif product == "upper":
        M = A.T.dot(A)
    else:
        M = A

    C = cholesky(M)
    f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
    apply_nodes = f.maker.fgraph.apply_nodes

    no_cholesky_in_graph = not any(
        isinstance(node.op, Cholesky) for node in apply_nodes
    )
    assert no_cholesky_in_graph == transform_removes_chol

    if transform_transposes:
        assert any(
            isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
            for node in apply_nodes
        )

    # Test some concrete values through f.
    # These must be lower triangular (f assumes they are when A is tagged).
    Avs = [
        np.eye(1, dtype=pytensor.config.floatX),
        np.eye(10, dtype=pytensor.config.floatX),
        np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX),
    ]
    if not tag:
        # Without a tag the input need not be triangular, but it must be
        # positive definite.
        Avs.append(
            np.ones((4, 4), dtype=pytensor.config.floatX)
            + np.eye(4, dtype=pytensor.config.floatX)
        )

    for Av in Avs:
        if tag == "upper":
            # Transposing turns the lower-triangular fixtures into upper ones.
            Av = Av.T

        if product == "lower":
            Mv = Av.dot(Av.T)
        elif product == "upper":
            Mv = Av.T.dot(Av)
        else:
            Mv = Av

        assert np.allclose(scipy.linalg.cholesky(Mv, lower=lower), f(Av))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论