Unverified commit 862c4169, authored by Jesse Grabowski, committed by GitHub

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple…

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple `pt.solve` when `assume_a = "pos"` (#1467) * Extend decomp+solve rewrite machinery to `assume_a="pos"` * Update rewrite name in test * Refactor tests to be nicer * Respect core op `lower` flag when rewriting to ChoSolve
Parent cca20eb7
......@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite, lower):
    """Factorize ``A`` once so the factorization can be reused by several solves.

    Parameters
    ----------
    A : TensorVariable
        Square matrix to factorize.
    assume_a : str
        Structure assumed for ``A``: ``"gen"`` (general), ``"tridiagonal"``,
        or ``"pos"`` (symmetric positive definite).
    check_finite : bool
        Whether the factorization op should check inputs for non-finite values.
    lower : bool
        For ``assume_a="pos"``, whether to return the lower-triangular
        Cholesky factor; ignored by the other branches.

    Returns
    -------
    TensorVariable
        The decomposition of ``A`` (LU factors or Cholesky factor).

    Raises
    ------
    NotImplementedError
        If ``assume_a`` is not one of the supported structures.
    """
    if assume_a == "gen":
        return lu_factor(A, check_finite=check_finite)
    elif assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU factorization
        return tridiagonal_lu_factor(A)
    elif assume_a == "pos":
        return cholesky(A, lower=lower, check_finite=check_finite)
    else:
        raise NotImplementedError
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a
if assume_a == "gen":
return lu_solve(
A_decomp,
......@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim,
transposed=transposed,
)
elif assume_a == "pos":
# We can ignore the transposed argument here because A is symmetric by assumption
return cho_solve(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else:
raise NotImplementedError
def _split_lu_solve_steps(
def _split_decomp_and_solve_steps(
fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
):
if not isinstance(node.op.core_op, Solve):
......@@ -133,13 +146,21 @@ def _split_lu_solve_steps(
if client.op.core_op.check_finite:
check_finite_decomp = True
break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {}
for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs
new_x = solve_lu_decomposed_system(
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
new_x = solve_decomposed_system(
A_decomp,
b,
transposed=transposed,
lower=lower,
core_solve_op=client.op.core_op,
)
[old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
......@@ -149,7 +170,7 @@ def _split_lu_solve_steps(
return replacements
def _scan_split_non_sequence_lu_decomposition_solve(
def _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, *, allowed_assume_a: Container[str]
):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
......@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore
replace_dict = _split_lu_solve_steps(
replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph,
inner_node,
eager=True,
......@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize
@node_rewriter([Blockwise])
def reuse_decomposition_multiple_solves(fgraph, node):
    """Rewrite multiple ``Solve`` ops sharing one ``A`` into one decomposition + solves.

    Supported structures are general ("gen"), tridiagonal, and symmetric
    positive definite ("pos", via Cholesky).
    """
    return _split_decomp_and_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
    )
@node_rewriter([Scan])
def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
    """Hoist the decomposition of a ``Scan``-invariant ``A`` out of the loop.

    Applies to solves with ``assume_a`` in {"gen", "tridiagonal", "pos"}.
    """
    return _scan_split_non_sequence_decomposition_and_solve(
        fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
    )
scan_seqopt1.register(
"scan_split_non_sequence_lu_decomposition_solve",
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run",
"scan",
"scan_pushout",
......@@ -231,28 +252,30 @@ scan_seqopt1.register(
@node_rewriter([Blockwise])
def reuse_decomposition_multiple_solves_jax(fgraph, node):
    """JAX variant of the decomposition-reuse rewrite.

    Tridiagonal factorization is not available in the JAX backend, so only
    "gen" and "pos" are rewritten here.
    """
    return _split_decomp_and_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
    )
# Register the JAX-only rewrite in the specialize database, tagged "jax" so it
# only runs for the JAX backend.
optdb["specialize"].register(
    reuse_decomposition_multiple_solves_jax.__name__,
    in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
)
@node_rewriter([Scan])
def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
    """JAX variant of the Scan decomposition-hoisting rewrite.

    Only "gen" and "pos" are handled; tridiagonal factorization is not
    available in the JAX backend.
    """
    return _scan_split_non_sequence_decomposition_and_solve(
        fgraph, node, allowed_assume_a={"gen", "pos"}
    )
scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
position=2,
......
......@@ -6,8 +6,8 @@ from pytensor.compile.mode import get_default_mode
from pytensor.gradient import grad
from pytensor.scan.op import Scan
from pytensor.tensor._linalg.solve.rewriting import (
reuse_lu_decomposition_multiple_solves,
scan_split_non_sequence_lu_decomposition_solve,
reuse_decomposition_multiple_solves,
scan_split_non_sequence_decomposition_and_solve,
)
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
......@@ -15,62 +15,70 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg import solve
from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular
from pytensor.tensor.slinalg import (
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveTriangular,
)
from pytensor.tensor.type import tensor
class DecompSolveOpCounter:
    """Count decomposition/solve ops in a graph for one rewrite strategy.

    Parameters
    ----------
    solve_op : type
        The core op that performs the solve after decomposition
        (e.g. ``SolveTriangular``, ``CholeskySolve``).
    decomp_op : type
        The core op that performs the decomposition
        (e.g. ``LUFactor``, ``Cholesky``).
    solve_op_value : float
        Weight per matching solve node. LU-based rewrites emit two triangular
        solves per original solve, so they use 0.5 to count pairs as one.
    """

    def __init__(self, solve_op, decomp_op, solve_op_value: float = 1.0):
        self.solve_op = solve_op
        self.decomp_op = decomp_op
        self.solve_op_value = solve_op_value

    def check_node_op_or_core_op(self, node, op):
        # Match the op directly, or wrapped in a Blockwise (batched) node.
        return isinstance(node.op, op) or (
            isinstance(node.op, Blockwise) and isinstance(node.op.core_op, op)
        )

    def count_vanilla_solve_nodes(self, nodes) -> int:
        """Count un-rewritten ``Solve`` nodes (directly or under Blockwise)."""
        return sum(self.check_node_op_or_core_op(node, Solve) for node in nodes)

    def count_decomp_nodes(self, nodes) -> int:
        """Count decomposition nodes for this strategy."""
        return sum(
            self.check_node_op_or_core_op(node, self.decomp_op) for node in nodes
        )

    def count_solve_nodes(self, nodes) -> int:
        """Count post-decomposition solve nodes, weighted by ``solve_op_value``."""
        count = sum(
            self.solve_op_value * self.check_node_op_or_core_op(node, self.solve_op)
            for node in nodes
        )
        return int(count)
# Pre-configured counters, one per decomposition strategy exercised by the tests.
LUOpCounter = DecompSolveOpCounter(
    solve_op=SolveTriangular,
    decomp_op=LUFactor,
    # Each rewrite introduces two triangular solves, so count them as 1/2 each
    solve_op_value=0.5,
)
# Tridiagonal LU uses a single dedicated solve op per solve, so weight 1.
TriDiagLUOpCounter = DecompSolveOpCounter(
    solve_op=SolveLUFactorTridiagonal, decomp_op=LUFactorTridiagonal, solve_op_value=1.0
)
# Cholesky rewrite emits one CholeskySolve per solve, so weight 1.
CholeskyOpCounter = DecompSolveOpCounter(
    solve_op=CholeskySolve, decomp_op=Cholesky, solve_op_value=1.0
)
@pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
@pytest.mark.parametrize(
"assume_a, counter",
(
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transposed):
rewrite_name = reuse_decomposition_multiple_solves.__name__
mode = get_default_mode()
A = tensor("A", shape=(3, 3))
......@@ -80,19 +88,22 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
grad_x_wrt_A = grad(x.sum(), A)
fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 2
assert count_lu_decom_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0
assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2
assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name))
opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert counter.count_decomp_nodes(opt_nodes) == 1
assert counter.count_solve_nodes(opt_nodes) == 2
# Make sure results are correct
rng = np.random.default_rng(31)
A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0, resg0 = fn_no_opt(A_test, b_test)
resx1, resg1 = fn_opt(A_test, b_test)
......@@ -102,9 +113,16 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
@pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_blockwise(assume_a, transposed):
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
@pytest.mark.parametrize(
"assume_a, counter",
(
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed):
rewrite_name = reuse_decomposition_multiple_solves.__name__
mode = get_default_mode()
A = tensor("A", shape=(3, 3))
......@@ -113,30 +131,40 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed):
x = solve(A, b, assume_a=assume_a, transposed=transposed)
fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 1
assert count_lu_decom_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0
assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1
assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, b], [x], mode=mode.including(rewrite_name))
opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 1
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert counter.count_decomp_nodes(opt_nodes) == 1
assert counter.count_solve_nodes(opt_nodes) == 1
# Make sure results are correct
rng = np.random.default_rng(31)
A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0 = fn_no_opt(A_test, b_test)
resx1 = fn_opt(A_test, b_test)
rtol = rtol = 1e-7 if config.floatX == "float64" else 1e-4
rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_scan(assume_a, transposed):
rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__
@pytest.mark.parametrize(
"assume_a, counter",
(
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__
mode = get_default_mode()
A = tensor("A", shape=(3, 3))
......@@ -158,23 +186,26 @@ def test_lu_decomposition_reused_scan(assume_a, transposed):
node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 1
assert count_lu_decom_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0
assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1
assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name))
[opt_scan_node] = [
node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
opt_nodes = opt_scan_node.op.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
# The LU decomp is outside of the scan!
assert count_lu_decom_nodes(opt_nodes) == 0
assert count_lu_solve_nodes(opt_nodes) == 1
assert counter.count_decomp_nodes(opt_nodes) == 0
assert counter.count_solve_nodes(opt_nodes) == 1
# Make sure results are correct
rng = np.random.default_rng(170)
A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype)
resx0 = fn_no_opt(A_test, x0_test)
resx1 = fn_opt(A_test, x0_test)
......@@ -182,23 +213,30 @@ def test_lu_decomposition_reused_scan(assume_a, transposed):
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
def test_lu_decomposition_reused_preserves_check_finite():
@pytest.mark.parametrize(
"assume_a, counter",
(
("gen", LUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_decomposition_reused_preserves_check_finite(assume_a, counter):
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
rewrite_name = reuse_decomposition_multiple_solves.__name__
A = tensor("A", shape=(2, 2))
b1 = tensor("b1", shape=(2,))
b2 = tensor("b2", shape=(2,))
x1 = solve(A, b1, assume_a="gen", check_finite=True)
x2 = solve(A, b2, assume_a="gen", check_finite=False)
x1 = solve(A, b1, assume_a=assume_a, check_finite=True)
x2 = solve(A, b2, assume_a=assume_a, check_finite=False)
fn_opt = function(
[A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
)
opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert counter.count_decomp_nodes(opt_nodes) == 1
assert counter.count_solve_nodes(opt_nodes) == 2
# We should get an error if A or b1 is non finite
A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
......
......@@ -581,7 +581,7 @@ class TestInplace:
mode = get_default_mode().excluding(
"batched_vector_b_solve_to_matrix_b_solve",
"reuse_lu_decomposition_multiple_solves",
"reuse_decomposition_multiple_solves",
)
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment