Unverified · 提交 9be43d07 · 作者: jessegrabowski · 提交者: GitHub

Fix bug in `tag_solve_triangular` rewrite (#383)

* Fix bug in `tag_solve_triangular` rewrite
* Rename `tag_solve_triangular` to `generic_solve_to_solve_triangular`
上级 7a82a3f4
......@@ -11,7 +11,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
logger = logging.getLogger(__name__)
......@@ -50,31 +50,30 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_canonicalize
@node_rewriter([Solve])
def tag_solve_triangular(fgraph, node):
def generic_solve_to_solve_triangular(fgraph, node):
"""
If a general solve() is applied to the output of a cholesky op, then
If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if isinstance(node.op, Solve):
if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [Solve(assume_a="sym", lower=True)(A, b)]
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [SolveTriangular(lower=True)(A, b)]
else:
return [SolveTriangular(lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [SolveTriangular(lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [Solve(assume_a="sym", lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=True)(A, b)]
return [SolveTriangular(lower=True)(A, b)]
@register_canonicalize
......
......@@ -2,6 +2,7 @@ import numpy as np
import numpy.linalg
import pytest
import scipy.linalg
from numpy.testing import assert_allclose
import pytensor
from pytensor import function
......@@ -12,7 +13,7 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
......@@ -81,25 +82,46 @@ def test_transinv_to_invtrans():
assert node.inputs[0].name == "X"
def test_tag_solve_triangular():
def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A")
x = vector("x")
x = matrix("x")
L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
f = pytensor.function([A, x], b1)
X = np.random.normal(size=(10, 10)).astype(config.floatX)
X = X @ X.T
X_chol = np.linalg.cholesky(X)
eye = np.eye(10, dtype=config.floatX)
if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]
assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(
f(X, eye) @ X_chol, eye, atol=1e-8 if config.floatX.endswith("64") else 1e-4
)
f = pytensor.function([A, x], b2)
if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and not node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]
assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(
f(X, eye).T @ X_chol,
eye,
atol=1e-8 if config.floatX.endswith("64") else 1e-4,
)
def test_matrix_inverse_solve():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论