Unverified 提交 9be43d07 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: GitHub

Fix bug in `tag_solve_triangular` rewrite (#383)

* Fix bug in tag_solve_triangular rewrite * Rename tag_solve_triangular to generic_solve_to_solve_triangular
上级 7a82a3f4
...@@ -11,7 +11,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -11,7 +11,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,20 +50,19 @@ def inv_as_solve(fgraph, node): ...@@ -50,20 +50,19 @@ def inv_as_solve(fgraph, node):
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@node_rewriter([Solve]) @node_rewriter([Solve])
def tag_solve_triangular(fgraph, node): def generic_solve_to_solve_triangular(fgraph, node):
""" """
If a general solve() is applied to the output of a cholesky op, then If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve. replace it with a triangular solve.
""" """
if isinstance(node.op, Solve): if isinstance(node.op, Solve):
if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky): if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower: if A.owner.op.lower:
return [Solve(assume_a="sym", lower=True)(A, b)] return [SolveTriangular(lower=True)(A, b)]
else: else:
return [Solve(assume_a="sym", lower=False)(A, b)] return [SolveTriangular(lower=False)(A, b)]
if ( if (
A.owner A.owner
and isinstance(A.owner.op, DimShuffle) and isinstance(A.owner.op, DimShuffle)
...@@ -72,9 +71,9 @@ def tag_solve_triangular(fgraph, node): ...@@ -72,9 +71,9 @@ def tag_solve_triangular(fgraph, node):
(A_T,) = A.owner.inputs (A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky): if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower: if A_T.owner.op.lower:
return [Solve(assume_a="sym", lower=False)(A, b)] return [SolveTriangular(lower=False)(A, b)]
else: else:
return [Solve(assume_a="sym", lower=True)(A, b)] return [SolveTriangular(lower=True)(A, b)]
@register_canonicalize @register_canonicalize
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import numpy.linalg import numpy.linalg
import pytest import pytest
import scipy.linalg import scipy.linalg
from numpy.testing import assert_allclose
import pytensor import pytensor
from pytensor import function from pytensor import function
...@@ -12,7 +13,7 @@ from pytensor.tensor.elemwise import DimShuffle ...@@ -12,7 +13,7 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -81,25 +82,46 @@ def test_transinv_to_invtrans(): ...@@ -81,25 +82,46 @@ def test_transinv_to_invtrans():
assert node.inputs[0].name == "X" assert node.inputs[0].name == "X"
def test_tag_solve_triangular(): def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True) cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False) cholesky_upper = Cholesky(lower=False)
A = matrix("A") A = matrix("A")
x = vector("x") x = matrix("x")
L = cholesky_lower(A) L = cholesky_lower(A)
U = cholesky_upper(A) U = cholesky_upper(A)
b1 = solve(L, x) b1 = solve(L, x)
b2 = solve(U, x) b2 = solve(U, x)
f = pytensor.function([A, x], b1) f = pytensor.function([A, x], b1)
X = np.random.normal(size=(10, 10)).astype(config.floatX)
X = X @ X.T
X_chol = np.linalg.cholesky(X)
eye = np.eye(10, dtype=config.floatX)
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort(): toposort = f.maker.fgraph.toposort()
if isinstance(node.op, Solve): op_list = [node.op for node in toposort]
assert node.op.assume_a != "gen" and node.op.lower
assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(
f(X, eye) @ X_chol, eye, atol=1e-8 if config.floatX.endswith("64") else 1e-4
)
f = pytensor.function([A, x], b2) f = pytensor.function([A, x], b2)
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort(): toposort = f.maker.fgraph.toposort()
if isinstance(node.op, Solve): op_list = [node.op for node in toposort]
assert node.op.assume_a != "gen" and not node.op.lower assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(
f(X, eye).T @ X_chol,
eye,
atol=1e-8 if config.floatX.endswith("64") else 1e-4,
)
def test_matrix_inverse_solve(): def test_matrix_inverse_solve():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论