提交 d88c7351 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Decompose Tridiagonal Solve into core steps

上级 43d8e303
...@@ -477,6 +477,9 @@ JAX = Mode( ...@@ -477,6 +477,9 @@ JAX = Mode(
"fusion", "fusion",
"inplace", "inplace",
"scan_save_mem_prealloc", "scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
], ],
), ),
) )
......
...@@ -6,6 +6,7 @@ from numba.np.linalg import ensure_lapack ...@@ -6,6 +6,7 @@ from numba.np.linalg import ensure_lapack
from numpy import ndarray from numpy import ndarray
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
...@@ -20,6 +21,10 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -20,6 +21,10 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_solve_check, _solve_check,
_trans_char_to_int, _trans_char_to_int,
) )
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
SolveLUFactorTridiagonal,
)
@numba_njit @numba_njit
...@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl): ...@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
def _gttrf( def _gttrf(
dl: ndarray, d: ndarray, du: ndarray dl: ndarray,
d: ndarray,
du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
"""Placeholder for LU factorization of tridiagonal matrix.""" """Placeholder for LU factorization of tridiagonal matrix."""
return # type: ignore return # type: ignore
...@@ -45,8 +55,12 @@ def gttrf_impl( ...@@ -45,8 +55,12 @@ def gttrf_impl(
dl: ndarray, dl: ndarray,
d: ndarray, d: ndarray,
du: ndarray, du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> Callable[ ) -> Callable[
[ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int] [ndarray, ndarray, ndarray, bool, bool, bool],
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrf") _check_scipy_linalg_matrix(dl, "gttrf")
...@@ -60,12 +74,24 @@ def gttrf_impl( ...@@ -60,12 +74,24 @@ def gttrf_impl(
dl: ndarray, dl: ndarray,
d: ndarray, d: ndarray,
du: ndarray, du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
n = np.int32(d.shape[-1]) n = np.int32(d.shape[-1])
ipiv = np.empty(n, dtype=np.int32) ipiv = np.empty(n, dtype=np.int32)
du2 = np.empty(n - 2, dtype=dtype) du2 = np.empty(n - 2, dtype=dtype)
info = val_to_int_ptr(0) info = val_to_int_ptr(0)
if not overwrite_dl or not dl.flags.f_contiguous:
dl = dl.copy()
if not overwrite_d or not d.flags.f_contiguous:
d = d.copy()
if not overwrite_du or not du.flags.f_contiguous:
du = du.copy()
numba_gttrf( numba_gttrf(
val_to_int_ptr(n), val_to_int_ptr(n),
dl.view(w_type).ctypes, dl.view(w_type).ctypes,
...@@ -133,10 +159,23 @@ def gttrs_impl( ...@@ -133,10 +159,23 @@ def gttrs_impl(
nrhs = 1 if b.ndim == 1 else int(b.shape[-1]) nrhs = 1 if b.ndim == 1 else int(b.shape[-1])
info = val_to_int_ptr(0) info = val_to_int_ptr(0)
if overwrite_b and b.flags.f_contiguous: if not overwrite_b or not b.flags.f_contiguous:
b_copy = b b = _copy_to_fortran_order_even_if_1d(b)
else:
b_copy = _copy_to_fortran_order_even_if_1d(b) if not dl.flags.f_contiguous:
dl = dl.copy()
if not d.flags.f_contiguous:
d = d.copy()
if not du.flags.f_contiguous:
du = du.copy()
if not du2.flags.f_contiguous:
du2 = du2.copy()
if not ipiv.flags.f_contiguous:
ipiv = ipiv.copy()
numba_gttrs( numba_gttrs(
val_to_int_ptr(_trans_char_to_int(trans)), val_to_int_ptr(_trans_char_to_int(trans)),
...@@ -147,12 +186,12 @@ def gttrs_impl( ...@@ -147,12 +186,12 @@ def gttrs_impl(
du.view(w_type).ctypes, du.view(w_type).ctypes,
du2.view(w_type).ctypes, du2.view(w_type).ctypes,
ipiv.ctypes, ipiv.ctypes,
b_copy.view(w_type).ctypes, b.view(w_type).ctypes,
val_to_int_ptr(n), val_to_int_ptr(n),
info, info,
) )
return b_copy, int_ptr_to_val(info) return b, int_ptr_to_val(info)
return impl return impl
...@@ -283,7 +322,9 @@ def _tridiagonal_solve_impl( ...@@ -283,7 +322,9 @@ def _tridiagonal_solve_impl(
anorm = tridiagonal_norm(du, d, dl) anorm = tridiagonal_norm(du, d, dl)
dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du) dl, d, du, du2, IPIV, INFO = _gttrf(
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
)
_solve_check(n, INFO) _solve_check(n, INFO)
X, INFO = _gttrs( X, INFO = _gttrs(
...@@ -297,3 +338,48 @@ def _tridiagonal_solve_impl( ...@@ -297,3 +338,48 @@ def _tridiagonal_solve_impl(
return X return X
return impl return impl
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Numba dispatch for LUFactorTridiagonal: delegate to the gttrf helper."""
    # Read the overwrite flags off the Op before defining the closure, so the
    # jitted function captures them as plain compile-time booleans.
    ow_dl = op.overwrite_dl
    ow_d = op.overwrite_d
    ow_du = op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # _gttrf also returns the LAPACK info code; it is discarded here,
        # exactly as the Op's outputs do not include it.
        dl_f, d_f, du_f, du2, ipiv, _info = _gttrf(
            dl,
            d,
            du,
            overwrite_dl=ow_dl,
            overwrite_d=ow_d,
            overwrite_du=ow_du,
        )
        return dl_f, d_f, du_f, du2, ipiv

    return lu_factor_tridiagonal
@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Numba dispatch for SolveLUFactorTridiagonal: delegate to the gttrs helper."""
    # Capture Op properties as locals so the closure treats them as constants.
    ow_b = op.overwrite_b
    trans = op.transposed

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        # The LAPACK info code returned by _gttrs is discarded; only the
        # solution array is an output of this Op.
        x, _info = _gttrs(
            dl,
            d,
            du,
            du2,
            ipiv,
            b,
            overwrite_b=ow_b,
            trans=trans,
        )
        return x

    return solve_lu_factor_tridiagonal
from collections.abc import Container from collections.abc import Container
from copy import copy from copy import copy
from pytensor.compile import optdb
from pytensor.graph import Constant, graph_inputs from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1 from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.tridiagonal import (
tridiagonal_lu_factor,
tridiagonal_lu_solve,
)
from pytensor.tensor.basic import atleast_Nd from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
...@@ -17,18 +22,32 @@ from pytensor.tensor.variable import TensorVariable ...@@ -17,18 +22,32 @@ from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite): def decompose_A(A, assume_a, check_finite):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=check_finite) return lu_factor(A, check_finite=check_finite)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A)
else: else:
raise NotImplementedError raise NotImplementedError
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
if core_solve_op.assume_a == "gen": b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a
if assume_a == "gen":
return lu_solve( return lu_solve(
A_decomp, A_decomp,
b, b,
b_ndim=b_ndim,
trans=transposed, trans=transposed,
b_ndim=core_solve_op.b_ndim, check_finite=check_finite,
check_finite=core_solve_op.check_finite, )
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU solve
return tridiagonal_lu_solve(
A_decomp,
b,
b_ndim=b_ndim,
transposed=transposed,
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve( ...@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node): def reuse_lu_decomposition_multiple_solves(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) return _split_lu_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
)
@node_rewriter([Scan]) @node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve( return _scan_split_non_sequence_lu_decomposition_solve(
fgraph, node, allowed_assume_a={"gen"} fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
) )
...@@ -207,3 +228,32 @@ scan_seqopt1.register( ...@@ -207,3 +228,32 @@ scan_seqopt1.register(
"scan_pushout", "scan_pushout",
position=2, position=2,
) )
# JAX-restricted variant of the rewrite: only assume_a="gen" is allowed,
# because the tridiagonal LU Ops are not supported by the JAX backend
# (the generic variants are excluded from the JAX Mode for this reason).
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
    """Split Solve nodes into LU factor + solve steps (general matrices only)."""
    return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})


# Register under the "jax" tag (not the db name) so only JAX mode picks it up.
optdb["specialize"].register(
    reuse_lu_decomposition_multiple_solves_jax.__name__,
    in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
)
# JAX-restricted variant of the Scan rewrite: only assume_a="gen" is allowed,
# because the tridiagonal LU Ops are not supported by the JAX backend.
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
    """Hoist LU factorization of non-sequence A out of a Scan (general matrices only)."""
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph, node, allowed_assume_a={"gen"}
    )


# Register under the "jax" tag at position=2 (before "scan_pushout", matching
# the registration of the unrestricted variant) so only JAX mode picks it up.
scan_seqopt1.register(
    scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
    in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
    position=2,
)
import typing
from typing import TYPE_CHECKING
import numpy as np
from scipy.linalg import get_lapack_funcs
from pytensor.graph import Apply, Op
from pytensor.tensor.basic import as_tensor, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import tensor, vector
from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
class LUFactorTridiagonal(Op):
"""Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
__props__ = (
"overwrite_dl",
"overwrite_d",
"overwrite_du",
)
gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
self.destroy_map = dm = {}
if overwrite_dl:
dm[0] = [0]
if overwrite_d:
dm[1] = [1]
if overwrite_du:
dm[2] = [2]
self.overwrite_dl = overwrite_dl
self.overwrite_d = overwrite_d
self.overwrite_du = overwrite_du
super().__init__()
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
return type(self)(
overwrite_dl=0 in allowed_inplace_inputs,
overwrite_d=1 in allowed_inplace_inputs,
overwrite_du=2 in allowed_inplace_inputs,
)
def make_node(self, dl, d, du):
dl, d, du = map(as_tensor, (dl, d, du))
if not all(inp.type.ndim == 1 for inp in (dl, d, du)):
raise ValueError("Diagonals must be vectors")
ndl, nd, ndu = (inp.type.shape[-1] for inp in (dl, d, du))
match (ndl, nd, ndu):
case (int(), _, _):
n = ndl + 1
case (_, int(), _):
n = nd + 1
case (_, _, int()):
n = ndu + 1
case _:
n = None
dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du)]
out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
outputs = [
vector(shape=(None if n is None else (n - 1),), dtype=out_dtype),
vector(shape=(n,), dtype=out_dtype),
vector(shape=(None if n is None else n - 1,), dtype=out_dtype),
vector(shape=(None if n is None else n - 2,), dtype=out_dtype),
vector(shape=(n,), dtype=np.int32),
]
return Apply(self, [dl, d, du], outputs)
def perform(self, node, inputs, output_storage):
gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
dl, d, du, du2, ipiv, _ = gttrf(
*inputs,
overwrite_dl=self.overwrite_dl,
overwrite_d=self.overwrite_d,
overwrite_du=self.overwrite_du,
)
output_storage[0][0] = dl
output_storage[1][0] = d
output_storage[2][0] = du
output_storage[3][0] = du2
output_storage[4][0] = ipiv
class SolveLUFactorTridiagonal(Op):
"""Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs)."""
__props__ = ("b_ndim", "overwrite_b", "transposed")
def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
if b_ndim not in (1, 2):
raise ValueError("b_ndim must be 1 or 2")
if b_ndim == 1:
self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
else:
self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
if overwrite_b:
self.destroy_map = {0: [5]}
self.b_ndim = b_ndim
self.transposed = transposed
self.overwrite_b = overwrite_b
super().__init__()
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
# b matrix is the 5th input
if 5 in allowed_inplace_inputs:
props = self._props_dict() # type: ignore
props["overwrite_b"] = True
return type(self)(**props)
return self
def make_node(self, dl, d, du, du2, ipiv, b):
dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))
if b.type.ndim != self.b_ndim:
raise ValueError("Wrong number of dimensions for input b.")
if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
raise ValueError("Inputs must be vectors")
ndl, nd, ndu, ndu2, nipiv = (
inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv)
)
nb = b.type.shape[0]
match (ndl, nd, ndu, ndu2, nipiv):
case (int(), _, _, _, _):
n = ndl + 1
case (_, int(), _, _, _):
n = nd
case (_, _, int(), _, _):
n = ndu + 1
case (_, _, _, int(), _):
n = ndu2 + 2
case (_, _, _, _, int()):
n = nipiv
case _:
n = nb
dummy_arrays = [
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
]
# Seems to always be float64?
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
if self.b_ndim == 1:
output_shape = (n,)
else:
output_shape = (n, b.type.shape[-1])
outputs = [tensor(shape=output_shape, dtype=out_dtype)]
return Apply(self, [dl, d, du, du2, ipiv, b], outputs)
def perform(self, node, inputs, output_storage):
gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
x, _ = gttrs(
*inputs,
overwrite_b=self.overwrite_b,
trans="N" if not self.transposed else "T",
)
output_storage[0][0] = x
def tridiagonal_lu_factor(
    a: "TensorLike",
) -> tuple[
    TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable
]:
    """Return the decomposition of A implied by a solve tridiagonal (LAPACK's gttrf)

    Parameters
    ----------
    a
        The input matrix.

    Returns
    -------
    dl, d, du, du2, ipiv
        The LU factorization of A.
    """
    # Extract the three diagonals of A: sub (-1), main (0) and super (+1).
    diags = [diagonal(a, offset=offset, axis1=-2, axis2=-1) for offset in (-1, 0, 1)]
    factored = typing.cast(
        list[TensorVariable], Blockwise(LUFactorTridiagonal())(*diags)
    )
    dl, d, du, du2, ipiv = factored
    return dl, d, du, du2, ipiv
def tridiagonal_lu_solve(
    a_diagonals: tuple[
        "TensorLike", "TensorLike", "TensorLike", "TensorLike", "TensorLike"
    ],
    b: "TensorLike",
    *,
    b_ndim: int,
    transposed: bool = False,
) -> TensorVariable:
    """Solve a tridiagonal system of equations using LU factorized inputs (LAPACK's gttrs).

    Parameters
    ----------
    a_diagonals
        The outputs of tridiagonal_lu_factor(A).
    b
        The right-hand side vector or matrix.
    b_ndim
        The number of dimensions of the right-hand side.
    transposed
        Whether to solve the transposed system.

    Returns
    -------
    TensorVariable
        The solution vector or matrix.
    """
    core_op = SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed)
    x = Blockwise(core_op)(*a_diagonals, b)
    return typing.cast(TensorVariable, x)
import numpy as np
import pytest
import scipy
from pytensor import In
from pytensor import tensor as pt
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
SolveLUFactorTridiagonal,
)
from pytensor.tensor.blockwise import Blockwise
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
@pytest.mark.parametrize("inplace", [False, True], ids=lambda x: f"inplace={x}")
def test_tridiagonal_lu_factor(inplace):
    """Check the numba dispatch of LUFactorTridiagonal against the Python
    implementation, including in-place and non-contiguous input behavior."""
    dl = pt.vector("dl", shape=(4,))
    d = pt.vector("d", shape=(5,))
    du = pt.vector("du", shape=(4,))
    lu_factor_outs = Blockwise(LUFactorTridiagonal())(dl, d, du)

    rng = np.random.default_rng(734)
    dl_test = rng.random(dl.type.shape)
    d_test = rng.random(d.type.shape)
    du_test = rng.random(du.type.shape)

    # Inputs are marked mutable per the `inplace` parametrization so the
    # inplace rewrite may (or may not) let the Op overwrite them.
    f, results = compare_numba_and_py(
        [
            In(dl, mutable=inplace),
            In(d, mutable=inplace),
            In(du, mutable=inplace),
        ],
        lu_factor_outs,
        test_inputs=[dl_test, d_test, du_test],
        inplace=True,
        numba_mode=numba_inplace_mode,
        eval_obj_mode=False,
    )

    # Test with contiguous inputs
    dl_test_contig = dl_test.copy()
    d_test_contig = d_test.copy()
    du_test_contig = du_test.copy()
    results_contig = f(dl_test_contig, d_test_contig, du_test_contig)
    for res, res_contig in zip(results, results_contig):
        np.testing.assert_allclose(res, res_contig)
    # Contiguous inputs should be mutated exactly when the graph is in-place.
    assert (dl_test_contig == dl_test).all() == (not inplace)
    assert (d_test_contig == d_test).all() == (not inplace)
    assert (du_test_contig == du_test).all() == (not inplace)

    # Test with non-contiguous inputs
    dl_test_not_contig = np.repeat(dl_test, 2)[::2]
    d_test_not_contig = np.repeat(d_test, 2)[::2]
    du_test_not_contig = np.repeat(du_test, 2)[::2]
    results_not_contig = f(dl_test_not_contig, d_test_not_contig, du_test_not_contig)
    for res, res_not_contig in zip(results, results_not_contig):
        np.testing.assert_allclose(res, res_not_contig)
    # Non-contiguous inputs have to be copied so are not modified in place
    assert (dl_test_not_contig == dl_test).all()
    assert (d_test_not_contig == d_test).all()
    assert (du_test_not_contig == du_test).all()
@pytest.mark.parametrize("transposed", [False, True], ids=lambda x: f"transposed={x}")
@pytest.mark.parametrize("inplace", [True, False], ids=lambda x: f"inplace={x}")
@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
    """Check the numba dispatch of SolveLUFactorTridiagonal against the
    Python implementation, for vector/matrix b, transposed solves, and
    in-place / non-contiguous b behavior."""
    scipy_gttrf = scipy.linalg.get_lapack_funcs("gttrf")

    dl = pt.tensor("dl", shape=(9,))
    d = pt.tensor("d", shape=(10,))
    du = pt.tensor("du", shape=(9,))
    du2 = pt.tensor("du2", shape=(8,))
    ipiv = pt.tensor("ipiv", shape=(10,), dtype="int32")
    diagonals = [dl, d, du, du2, ipiv]
    # b is a vector or a matrix depending on the b_ndim parametrization.
    b = pt.tensor("b", shape=(10, 25)[:b_ndim])

    x = Blockwise(SolveLUFactorTridiagonal(b_ndim=b.type.ndim, transposed=transposed))(
        *diagonals, b
    )

    rng = np.random.default_rng(787)
    A_test = rng.random((d.type.shape[0], d.type.shape[0]))
    # Factorize a random dense matrix's diagonals with scipy's gttrf to get
    # valid (dl, d, du, du2, ipiv) test inputs; the info code is dropped.
    *diagonals_test, _ = scipy_gttrf(
        *(np.diagonal(A_test, offset=o) for o in (-1, 0, 1))
    )
    b_test = rng.random(b.type.shape)
    f, res = compare_numba_and_py(
        [
            *diagonals,
            In(b, mutable=inplace),
        ],
        x,
        test_inputs=[*diagonals_test, b_test],
        inplace=True,
        numba_mode=numba_inplace_mode,
        eval_obj_mode=False,
    )

    # Test with contiguous_inputs
    diagonals_test_contig = [d_test.copy() for d_test in diagonals_test]
    # gttrs expects Fortran-ordered b, so a Fortran copy is the in-place path.
    b_test_contig = b_test.copy(order="F")
    res_contig = f(*diagonals_test_contig, b_test_contig)
    assert (res_contig == res).all()
    # b should be mutated exactly when the graph is in-place.
    assert (b_test == b_test_contig).all() == (not inplace)

    # Test with non-contiguous inputs
    diagonals_test_non_contig = [np.repeat(d_test, 2)[::2] for d_test in diagonals_test]
    b_test_non_contig = np.repeat(b_test, 2, axis=0)[::2]
    res_non_contig = f(*diagonals_test_non_contig, b_test_non_contig)
    assert (res_non_contig == res).all()
    # b must be copied when not contiguous so it can't be inplaced
    assert (b_test == b_test_non_contig).all()
...@@ -9,6 +9,10 @@ from pytensor.tensor._linalg.solve.rewriting import ( ...@@ -9,6 +9,10 @@ from pytensor.tensor._linalg.solve.rewriting import (
reuse_lu_decomposition_multiple_solves, reuse_lu_decomposition_multiple_solves,
scan_split_non_sequence_lu_decomposition_solve, scan_split_non_sequence_lu_decomposition_solve,
) )
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
SolveLUFactorTridiagonal,
)
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg import solve from pytensor.tensor.linalg import solve
from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular
...@@ -28,9 +32,10 @@ def count_vanilla_solve_nodes(nodes) -> int: ...@@ -28,9 +32,10 @@ def count_vanilla_solve_nodes(nodes) -> int:
def count_lu_decom_nodes(nodes) -> int: def count_lu_decom_nodes(nodes) -> int:
return sum( return sum(
( (
isinstance(node.op, LUFactor) isinstance(node.op, LUFactor | LUFactorTridiagonal)
or ( or (
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, LUFactor) isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, LUFactor | LUFactorTridiagonal)
) )
) )
for node in nodes for node in nodes
...@@ -40,27 +45,38 @@ def count_lu_decom_nodes(nodes) -> int: ...@@ -40,27 +45,38 @@ def count_lu_decom_nodes(nodes) -> int:
def count_lu_solve_nodes(nodes) -> int: def count_lu_solve_nodes(nodes) -> int:
count = sum( count = sum(
( (
# LUFactor uses 2 SolveTriangular nodes, so we count each as 0.5
0.5
* (
isinstance(node.op, SolveTriangular) isinstance(node.op, SolveTriangular)
or ( or (
isinstance(node.op, Blockwise) isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, SolveTriangular) and isinstance(node.op.core_op, SolveTriangular)
) )
) )
or (
isinstance(node.op, SolveLUFactorTridiagonal)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, SolveLUFactorTridiagonal)
)
)
)
for node in nodes for node in nodes
) )
# Each LU solve uses two Triangular solves return int(count)
return count // 2
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_forward_and_gradient(transposed): @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(2, 2)) A = tensor("A", shape=(3, 3))
b = tensor("b", shape=(2, 3)) b = tensor("b", shape=(3, 4))
x = solve(A, b, assume_a="gen", transposed=transposed) x = solve(A, b, assume_a=assume_a, transposed=transposed)
grad_x_wrt_A = grad(x.sum(), A) grad_x_wrt_A = grad(x.sum(), A)
fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
...@@ -80,20 +96,21 @@ def test_lu_decomposition_reused_forward_and_gradient(transposed): ...@@ -80,20 +96,21 @@ def test_lu_decomposition_reused_forward_and_gradient(transposed):
b_test = rng.random(b.type.shape, dtype=b.type.dtype) b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0, resg0 = fn_no_opt(A_test, b_test) resx0, resg0 = fn_no_opt(A_test, b_test)
resx1, resg1 = fn_opt(A_test, b_test) resx1, resg1 = fn_opt(A_test, b_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-6 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
np.testing.assert_allclose(resg0, resg1, rtol=rtol) np.testing.assert_allclose(resg0, resg1, rtol=rtol)
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_blockwise(transposed): @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_blockwise(assume_a, transposed):
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(2, 2)) A = tensor("A", shape=(3, 3))
b = tensor("b", shape=(2, 2, 3)) b = tensor("b", shape=(2, 3, 4))
x = solve(A, b, transposed=transposed) x = solve(A, b, assume_a=assume_a, transposed=transposed)
fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 1 assert count_vanilla_solve_nodes(no_opt_nodes) == 1
...@@ -112,19 +129,21 @@ def test_lu_decomposition_reused_blockwise(transposed): ...@@ -112,19 +129,21 @@ def test_lu_decomposition_reused_blockwise(transposed):
b_test = rng.random(b.type.shape, dtype=b.type.dtype) b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0 = fn_no_opt(A_test, b_test) resx0 = fn_no_opt(A_test, b_test)
resx1 = fn_opt(A_test, b_test) resx1 = fn_opt(A_test, b_test)
np.testing.assert_allclose(resx0, resx1) rtol = rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
def test_lu_decomposition_reused_scan(transposed): @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
def test_lu_decomposition_reused_scan(assume_a, transposed):
rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__ rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(2, 2)) A = tensor("A", shape=(3, 3))
x0 = tensor("b", shape=(2, 3)) x0 = tensor("b", shape=(3, 4))
xs, _ = scan( xs, _ = scan(
lambda xtm1, A: solve(A, xtm1, assume_a="general", transposed=transposed), lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
outputs_info=[x0], outputs_info=[x0],
non_sequences=[A], non_sequences=[A],
n_steps=10, n_steps=10,
...@@ -159,7 +178,7 @@ def test_lu_decomposition_reused_scan(transposed): ...@@ -159,7 +178,7 @@ def test_lu_decomposition_reused_scan(transposed):
x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype)
resx0 = fn_no_opt(A_test, x0_test) resx0 = fn_no_opt(A_test, x0_test)
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-6 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论