Unverified 提交 862c4169 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple…

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple `pt.solve` when `assume_a = "pos"` (#1467)

* Extend decomp+solve rewrite machinery to `assume_a="pos"`
* Update rewrite name in test
* Refactor tests to be nicer
* Respect core op `lower` flag when rewriting to ChoSolve
上级 cca20eb7
...@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite): def decompose_A(A, assume_a, check_finite, lower):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=check_finite) return lu_factor(A, check_finite=check_finite)
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization # We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A) return tridiagonal_lu_factor(A)
elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite)
else: else:
raise NotImplementedError raise NotImplementedError
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a assume_a = core_solve_op.assume_a
if assume_a == "gen": if assume_a == "gen":
return lu_solve( return lu_solve(
A_decomp, A_decomp,
...@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: ...@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim, b_ndim=b_ndim,
transposed=transposed, transposed=transposed,
) )
elif assume_a == "pos":
# We can ignore the transposed argument here because A is symmetric by assumption
return cho_solve(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else: else:
raise NotImplementedError raise NotImplementedError
def _split_lu_solve_steps( def _split_decomp_and_solve_steps(
fgraph, node, *, eager: bool, allowed_assume_a: Container[str] fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
): ):
if not isinstance(node.op.core_op, Solve): if not isinstance(node.op.core_op, Solve):
...@@ -133,13 +146,21 @@ def _split_lu_solve_steps( ...@@ -133,13 +146,21 @@ def _split_lu_solve_steps(
if client.op.core_op.check_finite: if client.op.core_op.check_finite:
check_finite_decomp = True check_finite_decomp = True
break break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {} replacements = {}
for client, transposed in A_solve_clients_and_transpose: for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs _, b = client.inputs
new_x = solve_lu_decomposed_system( new_x = solve_decomposed_system(
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op A_decomp,
b,
transposed=transposed,
lower=lower,
core_solve_op=client.op.core_op,
) )
[old_x] = client.outputs [old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype) new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
...@@ -149,7 +170,7 @@ def _split_lu_solve_steps( ...@@ -149,7 +170,7 @@ def _split_lu_solve_steps(
return replacements return replacements
def _scan_split_non_sequence_lu_decomposition_solve( def _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, *, allowed_assume_a: Container[str] fgraph, node, *, allowed_assume_a: Container[str]
): ):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step. """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
...@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve( ...@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
non_sequences = {equiv[non_seq] for non_seq in non_sequences} non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore inner_node = equiv[inner_node] # type: ignore
replace_dict = _split_lu_solve_steps( replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph, new_scan_fgraph,
inner_node, inner_node,
eager=True, eager=True,
...@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve( ...@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node): def reuse_decomposition_multiple_solves(fgraph, node):
return _split_lu_solve_steps( return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"} fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
) )
@node_rewriter([Scan]) @node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve( return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen", "tridiagonal"} fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_split_non_sequence_lu_decomposition_solve", scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True), in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
...@@ -231,28 +252,30 @@ scan_seqopt1.register( ...@@ -231,28 +252,30 @@ scan_seqopt1.register(
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node): def reuse_decomposition_multiple_solves_jax(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
)
optdb["specialize"].register( optdb["specialize"].register(
reuse_lu_decomposition_multiple_solves_jax.__name__, reuse_decomposition_multiple_solves_jax.__name__,
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True), in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
) )
@node_rewriter([Scan]) @node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node): def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve( return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen"} fgraph, node, allowed_assume_a={"gen", "pos"}
) )
scan_seqopt1.register( scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__, scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True), in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
position=2, position=2,
......
...@@ -581,7 +581,7 @@ class TestInplace: ...@@ -581,7 +581,7 @@ class TestInplace:
mode = get_default_mode().excluding( mode = get_default_mode().excluding(
"batched_vector_b_solve_to_matrix_b_solve", "batched_vector_b_solve_to_matrix_b_solve",
"reuse_lu_decomposition_multiple_solves", "reuse_decomposition_multiple_solves",
) )
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode) fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论