提交 9d2f8f11 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Reuse LU decomposition in Solve

上级 88c07f4b
......@@ -490,6 +490,8 @@ PYTORCH = Mode(
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
)
......
......@@ -2561,7 +2561,6 @@ scan_seqopt1.register(
position=1,
)
scan_seqopt1.register(
"scan_push_out_non_seq",
in2out(scan_push_out_non_seq, ignore_newtrees=True),
......@@ -2569,10 +2568,9 @@ scan_seqopt1.register(
"fast_run",
"scan",
"scan_pushout",
position=2,
position=3,
)
scan_seqopt1.register(
"scan_push_out_seq",
in2out(scan_push_out_seq, ignore_newtrees=True),
......@@ -2580,7 +2578,7 @@ scan_seqopt1.register(
"fast_run",
"scan",
"scan_pushout",
position=3,
position=4,
)
......@@ -2592,7 +2590,7 @@ scan_seqopt1.register(
"more_mem",
"scan",
"scan_pushout",
position=4,
position=5,
)
......@@ -2605,7 +2603,7 @@ scan_seqopt1.register(
"more_mem",
"scan",
"scan_pushout",
position=5,
position=6,
)
scan_eqopt2.register(
......
......@@ -114,6 +114,7 @@ from pytensor.tensor import (
# isort: off
import pytensor.tensor._linalg
from pytensor.tensor import linalg
from pytensor.tensor import special
from pytensor.tensor import signal
......
# Register rewrites
import pytensor.tensor._linalg.solve
# Register rewrites in the database
import pytensor.tensor._linalg.solve.rewriting
from collections.abc import Container
from copy import copy
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a):
    """Return a reusable decomposition of `A` for the given `assume_a` mode.

    Only the general case (``assume_a == "gen"``) is supported, for which an
    LU factorization (with finiteness checks disabled) is returned.
    """
    if assume_a != "gen":
        raise NotImplementedError
    return lu_factor(A, check_finite=False)
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
    """Solve ``A @ x = b`` (or ``A.T @ x = b`` when `transposed`) from a
    precomputed decomposition `A_decomp`.

    Only the general case (``assume_a == "gen"``) is supported, which uses
    `lu_solve` on the LU factors produced by `decompose_A`.
    """
    if assume_a != "gen":
        raise NotImplementedError
    return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
def _split_lu_solve_steps(
    fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
):
    """Split a Solve into an LU factorization shared by all related solves.

    Finds every Solve client that uses the same root `A` — directly, through a
    left expand_dims, or through a matrix transpose — and proposes replacing
    each of them with an `lu_solve` on a single shared `lu_factor(A)`.

    Parameters
    ----------
    fgraph
        The function graph whose `clients` mapping is used to discover the
        Solve nodes that share `A`.
    node
        A Blockwise node; only handled when its core op is a `Solve` whose
        `assume_a` is in `allowed_assume_a`.
    eager
        When True, rewrite even if there is only one Solve using `A`
        (used inside Scan, where the factorization can be hoisted out of the loop).
    allowed_assume_a
        The `assume_a` modes this rewrite may handle (currently only "gen").

    Returns
    -------
    dict mapping each old Solve output to its lu_solve-based replacement,
    or None when the rewrite does not apply.
    """
    if not isinstance(node.op.core_op, Solve):
        return None

    def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
        # Find the root variable of the first input to Solve
        # If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
        # the root variable is the pre-DimShuffled input.
        # Otherwise, `a` is considered the root variable.
        # We also return whether the root `a` is transposed.
        transposed = False
        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
            if a.owner.op.is_left_expand_dims:
                [a] = a.owner.inputs
            elif is_matrix_transpose(a):
                [a] = a.owner.inputs
                transposed = True
        return a, transposed

    def find_solve_clients(var, assume_a):
        # Collect Blockwise(Solve) clients that use `var` as their A input
        # (input index 0) with a matching `assume_a`, recursing through
        # left expand_dims DimShuffles.
        clients = []
        for cl, idx in fgraph.clients[var]:
            if (
                idx == 0
                and isinstance(cl.op, Blockwise)
                and isinstance(cl.op.core_op, Solve)
                and (cl.op.core_op.assume_a == assume_a)
            ):
                clients.append(cl)
            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
                # If it's a left expand_dims, recurse on the output
                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
        return clients

    assume_a = node.op.core_op.assume_a

    if assume_a not in allowed_assume_a:
        return None

    A, _ = get_root_A(node.inputs[0])

    # Find Solve using A (or left expand_dims of A)
    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
    #  that to the A_decomp outputs
    A_solve_clients_and_transpose = [
        (client, False) for client in find_solve_clients(A, assume_a)
    ]

    # Find Solves using A.T
    for cl, _ in fgraph.clients[A]:
        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
            A_T = cl.out
            A_solve_clients_and_transpose.extend(
                (client, True) for client in find_solve_clients(A_T, assume_a)
            )

    if not eager and len(A_solve_clients_and_transpose) == 1:
        # If there's a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
        # That's a "reuse" inside the inner vectorized loop
        batch_ndim = node.op.batch_ndim(node)
        (client, _) = A_solve_clients_and_transpose[0]
        original_A, b = client.inputs
        # Only worthwhile if some batch dim of A broadcasts against a
        # non-broadcast dim of b (i.e. the same A is factored repeatedly).
        if not any(
            a_bcast and not b_bcast
            for a_bcast, b_bcast in zip(
                original_A.type.broadcastable[:batch_ndim],
                b.type.broadcastable[:batch_ndim],
                strict=True,
            )
        ):
            return None

    A_decomp = decompose_A(A, assume_a=assume_a)

    replacements = {}
    for client, transposed in A_solve_clients_and_transpose:
        _, b = client.inputs
        b_ndim = client.op.core_op.b_ndim
        new_x = solve_lu_decomposed_system(
            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
        )
        [old_x] = client.outputs
        # Restore the original output's ndim and dtype so the replacement is
        # type-compatible with the variable it substitutes.
        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
        copy_stack_trace(old_x, new_x)
        replacements[old_x] = new_x

    return replacements
def _scan_split_non_sequence_lu_decomposition_solve(
    fgraph, node, *, allowed_assume_a: Container[str]
):
    """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.

    The LU decomposition step can then be pushed out of the inner loop by the
    `scan_pushout_non_sequences` rewrite.

    Returns the outputs of a new Scan op with the rewritten inner graph, or
    None (implicitly) when no inner Solve qualified.
    """
    scan_op: Scan = node.op
    non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
    new_scan_fgraph = scan_op.fgraph

    changed = False
    while True:
        for inner_node in new_scan_fgraph.toposort():
            if (
                isinstance(inner_node.op, Blockwise)
                and isinstance(inner_node.op.core_op, Solve)
                and inner_node.op.core_op.assume_a in allowed_assume_a
            ):
                A, b = inner_node.inputs
                # Only split when A depends exclusively on constants and
                # non-sequences, so the factorization is loop-invariant.
                if all(
                    (isinstance(root_inp, Constant) or (root_inp in non_sequences))
                    for root_inp in graph_inputs([A])
                ):
                    if new_scan_fgraph is scan_op.fgraph:
                        # Clone the first time to avoid mutating the original fgraph
                        new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
                        non_sequences = {equiv[non_seq] for non_seq in non_sequences}
                        inner_node = equiv[inner_node]  # type: ignore

                    # eager=True: inside the loop even a single use benefits,
                    # because the factorization will be hoisted out of Scan.
                    replace_dict = _split_lu_solve_steps(
                        new_scan_fgraph,
                        inner_node,
                        eager=True,
                        allowed_assume_a=allowed_assume_a,
                    )
                    assert (
                        isinstance(replace_dict, dict) and len(replace_dict) > 0
                    ), "Rewrite failed"
                    new_scan_fgraph.replace_all(replace_dict.items())
                    changed = True
                    break  # Break to start over with a fresh toposort
        else:  # no_break
            break  # Nothing else changed

    if not changed:
        return

    # Return a new scan to indicate that a rewrite was done
    new_scan_op = copy(scan_op)
    new_scan_op.fgraph = new_scan_fgraph
    new_outs = new_scan_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    """Rewrite several Solves that share the same A into one LU factorization
    followed by multiple LU solves."""
    replacements = _split_lu_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen"}
    )
    return replacements
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    """Split loop-invariant LU factorizations out of Solves inside a Scan."""
    new_outputs = _scan_split_non_sequence_lu_decomposition_solve(
        fgraph, node, allowed_assume_a={"gen"}
    )
    return new_outputs
# Register before the scan pushout rewrites (positions 3+) so the split-out
# LU factorization can subsequently be pushed out of the inner loop.
scan_seqopt1.register(
    "scan_split_non_sequence_lu_decomposition_solve",
    in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
    "fast_run",
    "scan",
    "scan_pushout",
    position=2,
)
......@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if ndims < 2:
return False
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
# Allow expand_dims on the left of the transpose
if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
transpose_order = (
*(["x"] * diff),
*transpose_order,
)
return node.op.new_order == transpose_order
return False
......
import numpy as np
import pytest
from pytensor import config, function, scan
from pytensor.compile.mode import get_default_mode
from pytensor.gradient import grad
from pytensor.scan.op import Scan
from pytensor.tensor._linalg.solve.rewriting import (
reuse_lu_decomposition_multiple_solves,
scan_split_non_sequence_lu_decomposition_solve,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg import solve
from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular
from pytensor.tensor.type import tensor
def count_vanilla_solve_nodes(nodes) -> int:
    """Count apply nodes that are a Solve, directly or as a Blockwise core op."""

    def _is_solve(op) -> bool:
        if isinstance(op, Solve):
            return True
        return isinstance(op, Blockwise) and isinstance(op.core_op, Solve)

    return sum(_is_solve(node.op) for node in nodes)
def count_lu_decom_nodes(nodes) -> int:
    """Count apply nodes that are an LUFactor, directly or as a Blockwise core op."""

    def _is_lu_factor(op) -> bool:
        if isinstance(op, LUFactor):
            return True
        return isinstance(op, Blockwise) and isinstance(op.core_op, LUFactor)

    return sum(_is_lu_factor(node.op) for node in nodes)
def count_lu_solve_nodes(nodes) -> int:
    """Count LU solves; each one lowers to a pair of SolveTriangular nodes."""

    def _is_triangular_solve(op) -> bool:
        if isinstance(op, SolveTriangular):
            return True
        return isinstance(op, Blockwise) and isinstance(op.core_op, SolveTriangular)

    n_triangular = sum(_is_triangular_solve(node.op) for node in nodes)
    # Each LU solve uses two Triangular solves
    return n_triangular // 2
@pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_forward_and_gradient(transposed):
    """Forward solve and its gradient should share one LU factorization."""
    rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
    compile_mode = get_default_mode()

    A = tensor("A", shape=(2, 2))
    b = tensor("b", shape=(2, 3))
    x = solve(A, b, assume_a="gen", transposed=transposed)
    grad_x_wrt_A = grad(x.sum(), A)

    f_base = function(
        [A, b], [x, grad_x_wrt_A], mode=compile_mode.excluding(rewrite_name)
    )
    base_nodes = f_base.maker.fgraph.apply_nodes
    # Without the rewrite: two vanilla Solves (forward + gradient), no LU ops.
    assert count_vanilla_solve_nodes(base_nodes) == 2
    assert count_lu_decom_nodes(base_nodes) == 0
    assert count_lu_solve_nodes(base_nodes) == 0

    f_rewritten = function(
        [A, b], [x, grad_x_wrt_A], mode=compile_mode.including(rewrite_name)
    )
    rewritten_nodes = f_rewritten.maker.fgraph.apply_nodes
    # With the rewrite: a single shared LU factorization feeding two LU solves.
    assert count_vanilla_solve_nodes(rewritten_nodes) == 0
    assert count_lu_decom_nodes(rewritten_nodes) == 1
    assert count_lu_solve_nodes(rewritten_nodes) == 2

    # Both compiled functions must agree numerically.
    rng = np.random.default_rng(31)
    A_val = rng.random(A.type.shape, dtype=A.type.dtype)
    b_val = rng.random(b.type.shape, dtype=b.type.dtype)
    x_base, g_base = f_base(A_val, b_val)
    x_new, g_new = f_rewritten(A_val, b_val)
    rtol = 1e-7 if config.floatX == "float64" else 1e-6
    np.testing.assert_allclose(x_base, x_new, rtol=rtol)
    np.testing.assert_allclose(g_base, g_new, rtol=rtol)
@pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_blockwise(transposed):
    """A single Solve broadcasting A over a batched b should be rewritten,
    factoring A once instead of once per batch entry."""
    rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
    compile_mode = get_default_mode()

    A = tensor("A", shape=(2, 2))
    b = tensor("b", shape=(2, 2, 3))
    x = solve(A, b, transposed=transposed)

    f_base = function([A, b], [x], mode=compile_mode.excluding(rewrite_name))
    base_nodes = f_base.maker.fgraph.apply_nodes
    # Without the rewrite: one vanilla (Blockwise) Solve, no LU ops.
    assert count_vanilla_solve_nodes(base_nodes) == 1
    assert count_lu_decom_nodes(base_nodes) == 0
    assert count_lu_solve_nodes(base_nodes) == 0

    f_rewritten = function([A, b], [x], mode=compile_mode.including(rewrite_name))
    rewritten_nodes = f_rewritten.maker.fgraph.apply_nodes
    # With the rewrite: one LU factorization and one (batched) LU solve.
    assert count_vanilla_solve_nodes(rewritten_nodes) == 0
    assert count_lu_decom_nodes(rewritten_nodes) == 1
    assert count_lu_solve_nodes(rewritten_nodes) == 1

    # Both compiled functions must agree numerically.
    rng = np.random.default_rng(31)
    A_val = rng.random(A.type.shape, dtype=A.type.dtype)
    b_val = rng.random(b.type.shape, dtype=b.type.dtype)
    out_base = f_base(A_val, b_val)
    out_new = f_rewritten(A_val, b_val)
    np.testing.assert_allclose(out_base, out_new)
@pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_scan(transposed):
    """An LU factorization of a Scan non-sequence should be hoisted out of the loop."""
    rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__
    compile_mode = get_default_mode()

    A = tensor("A", shape=(2, 2))
    x0 = tensor("b", shape=(2, 3))
    xs, _ = scan(
        lambda xtm1, A: solve(A, xtm1, assume_a="general", transposed=transposed),
        outputs_info=[x0],
        non_sequences=[A],
        n_steps=10,
    )

    f_base = function(
        [A, x0],
        [xs],
        mode=compile_mode.excluding(rewrite_name),
    )
    [base_scan_node] = [
        node for node in f_base.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    base_inner_nodes = base_scan_node.op.fgraph.apply_nodes
    # Without the rewrite the inner graph solves from scratch every step.
    assert count_vanilla_solve_nodes(base_inner_nodes) == 1
    assert count_lu_decom_nodes(base_inner_nodes) == 0
    assert count_lu_solve_nodes(base_inner_nodes) == 0

    f_rewritten = function(
        [A, x0], [xs], mode=compile_mode.including("scan", rewrite_name)
    )
    [rewritten_scan_node] = [
        node
        for node in f_rewritten.maker.fgraph.apply_nodes
        if isinstance(node.op, Scan)
    ]
    rewritten_inner_nodes = rewritten_scan_node.op.fgraph.apply_nodes
    assert count_vanilla_solve_nodes(rewritten_inner_nodes) == 0
    # The LU decomp is outside of the scan!
    assert count_lu_decom_nodes(rewritten_inner_nodes) == 0
    assert count_lu_solve_nodes(rewritten_inner_nodes) == 1

    # Both compiled functions must agree numerically.
    rng = np.random.default_rng(170)
    A_val = rng.random(A.type.shape, dtype=A.type.dtype)
    x0_val = rng.random(x0.type.shape, dtype=x0.type.dtype)
    out_base = f_base(A_val, x0_val)
    out_new = f_rewritten(A_val, x0_val)
    rtol = 1e-7 if config.floatX == "float64" else 1e-6
    np.testing.assert_allclose(out_base, out_new, rtol=rtol)
......@@ -579,7 +579,10 @@ class TestInplace:
else:
x = solve_fn(A, b, b_ndim=1)
mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
mode = get_default_mode().excluding(
"batched_vector_b_solve_to_matrix_b_solve",
"reuse_lu_decomposition_multiple_solves",
)
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
op = fn.maker.fgraph.outputs[0].owner.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论