Unverified commit 862c4169 · Author: Jesse Grabowski · Committer: GitHub

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple…

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple `pt.solve` when `assume_a = "pos"` (#1467)

* Extend decomp+solve rewrite machinery to `assume_a="pos"`
* Update rewrite name in test
* Refactor tests to be nicer
* Respect core op `lower` flag when rewriting to ChoSolve
上级 cca20eb7
......@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite, lower):
    """Factorize ``A`` with the decomposition that matches ``assume_a``.

    Parameters
    ----------
    A
        Matrix to factorize; handed unchanged to the selected factorization
        helper.
    assume_a : str
        Structure assumed for ``A``. ``"gen"`` selects ``lu_factor``,
        ``"tridiagonal"`` selects ``tridiagonal_lu_factor`` and ``"pos"``
        selects ``cholesky``.
    check_finite : bool
        Forwarded to ``lu_factor`` and ``cholesky``. It is not forwarded in
        the tridiagonal case.
    lower : bool
        Forwarded to ``cholesky``; only consulted when ``assume_a == "pos"``.

    Returns
    -------
    The value returned by the selected factorization helper.

    Raises
    ------
    NotImplementedError
        If ``assume_a`` is none of the values listed above.
    """
    if assume_a == "gen":
        return lu_factor(A, check_finite=check_finite)
    if assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU factorization
        return tridiagonal_lu_factor(A)
    if assume_a == "pos":
        return cholesky(A, lower=lower, check_finite=check_finite)
    # Name the offending value so a failing rewrite is easy to diagnose.
    raise NotImplementedError(
        f"No reusable decomposition is implemented for assume_a={assume_a!r}"
    )
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a
if assume_a == "gen":
return lu_solve(
A_decomp,
......@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim,
transposed=transposed,
)
elif assume_a == "pos":
# We can ignore the transposed argument here because A is symmetric by assumption
return cho_solve(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else:
raise NotImplementedError
def _split_lu_solve_steps(
def _split_decomp_and_solve_steps(
fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
):
if not isinstance(node.op.core_op, Solve):
......@@ -133,13 +146,21 @@ def _split_lu_solve_steps(
if client.op.core_op.check_finite:
check_finite_decomp = True
break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {}
for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs
new_x = solve_lu_decomposed_system(
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
new_x = solve_decomposed_system(
A_decomp,
b,
transposed=transposed,
lower=lower,
core_solve_op=client.op.core_op,
)
[old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
......@@ -149,7 +170,7 @@ def _split_lu_solve_steps(
return replacements
def _scan_split_non_sequence_lu_decomposition_solve(
def _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, *, allowed_assume_a: Container[str]
):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
......@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore
replace_dict = _split_lu_solve_steps(
replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph,
inner_node,
eager=True,
......@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize
@node_rewriter([Blockwise])
def reuse_decomposition_multiple_solves(fgraph, node):
    """Node rewriter for ``Blockwise`` nodes that splits decomposition from solve.

    Delegates to ``_split_decomp_and_solve_steps`` in non-eager mode
    (``eager=False``), permitting ``Solve`` ops whose ``assume_a`` is
    ``"gen"``, ``"tridiagonal"`` or ``"pos"``. The helper's result is
    returned unchanged.
    """
    return _split_decomp_and_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
    )
@node_rewriter([Scan])
def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
    """Node rewriter for ``Scan`` nodes that splits decomposition from solve.

    Delegates to ``_scan_split_non_sequence_decomposition_and_solve``, which
    handles the case where the ``A`` of a ``Solve`` inside the ``Scan`` is a
    function of non-sequences. ``Solve`` ops with ``assume_a`` equal to
    ``"gen"``, ``"tridiagonal"`` or ``"pos"`` are permitted. The helper's
    result is returned unchanged.
    """
    return _scan_split_non_sequence_decomposition_and_solve(
        fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
    )
scan_seqopt1.register(
"scan_split_non_sequence_lu_decomposition_solve",
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run",
"scan",
"scan_pushout",
......@@ -231,28 +252,30 @@ scan_seqopt1.register(
@node_rewriter([Blockwise])
def reuse_decomposition_multiple_solves_jax(fgraph, node):
    """Variant of the decomposition-reuse rewrite registered under the "jax" tag.

    Delegates to ``_split_decomp_and_solve_steps`` in non-eager mode
    (``eager=False``), permitting only ``assume_a`` values ``"gen"`` and
    ``"pos"``.
    NOTE(review): ``"tridiagonal"`` is left out here, presumably because the
    JAX backend lacks that factorization; confirm.
    """
    return _split_decomp_and_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
    )
# Register the JAX variant in the "specialize" database under the "jax" tag,
# keyed by the rewriter function's own name.
optdb["specialize"].register(
    reuse_decomposition_multiple_solves_jax.__name__,
    in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
)
@node_rewriter([Scan])
def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
    """Variant of the ``Scan`` decomposition-splitting rewrite for the "jax" tag.

    Delegates to ``_scan_split_non_sequence_decomposition_and_solve``,
    permitting only ``assume_a`` values ``"gen"`` and ``"pos"``. The helper's
    result is returned unchanged.
    """
    return _scan_split_non_sequence_decomposition_and_solve(
        fgraph, node, allowed_assume_a={"gen", "pos"}
    )
scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
position=2,
......
......@@ -581,7 +581,7 @@ class TestInplace:
mode = get_default_mode().excluding(
"batched_vector_b_solve_to_matrix_b_solve",
"reuse_lu_decomposition_multiple_solves",
"reuse_decomposition_multiple_solves",
)
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论