Commit 862c4169 (unverified) · Author: Jesse Grabowski · Committer: GitHub

Reuse `cholesky` decomposition with `cho_solve` in graphs with multiple `pt.solve` when `assume_a = "pos"` (#1467)

* Extend decomp+solve rewrite machinery to `assume_a="pos"`
* Update rewrite name in test
* Refactor tests to be nicer
* Respect core op `lower` flag when rewriting to ChoSolve
Parent commit: cca20eb7
...@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -15,24 +15,29 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a, check_finite): def decompose_A(A, assume_a, check_finite, lower):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=check_finite) return lu_factor(A, check_finite=check_finite)
elif assume_a == "tridiagonal": elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization # We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A) return tridiagonal_lu_factor(A)
elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite)
else: else:
raise NotImplementedError raise NotImplementedError
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a assume_a = core_solve_op.assume_a
if assume_a == "gen": if assume_a == "gen":
return lu_solve( return lu_solve(
A_decomp, A_decomp,
...@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: ...@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim, b_ndim=b_ndim,
transposed=transposed, transposed=transposed,
) )
elif assume_a == "pos":
# We can ignore the transposed argument here because A is symmetric by assumption
return cho_solve(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else: else:
raise NotImplementedError raise NotImplementedError
def _split_lu_solve_steps( def _split_decomp_and_solve_steps(
fgraph, node, *, eager: bool, allowed_assume_a: Container[str] fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
): ):
if not isinstance(node.op.core_op, Solve): if not isinstance(node.op.core_op, Solve):
...@@ -133,13 +146,21 @@ def _split_lu_solve_steps( ...@@ -133,13 +146,21 @@ def _split_lu_solve_steps(
if client.op.core_op.check_finite: if client.op.core_op.check_finite:
check_finite_decomp = True check_finite_decomp = True
break break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)
replacements = {} replacements = {}
for client, transposed in A_solve_clients_and_transpose: for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs _, b = client.inputs
new_x = solve_lu_decomposed_system( new_x = solve_decomposed_system(
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op A_decomp,
b,
transposed=transposed,
lower=lower,
core_solve_op=client.op.core_op,
) )
[old_x] = client.outputs [old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype) new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
...@@ -149,7 +170,7 @@ def _split_lu_solve_steps( ...@@ -149,7 +170,7 @@ def _split_lu_solve_steps(
return replacements return replacements
def _scan_split_non_sequence_lu_decomposition_solve( def _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, *, allowed_assume_a: Container[str] fgraph, node, *, allowed_assume_a: Container[str]
): ):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step. """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
...@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve( ...@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
non_sequences = {equiv[non_seq] for non_seq in non_sequences} non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore inner_node = equiv[inner_node] # type: ignore
replace_dict = _split_lu_solve_steps( replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph, new_scan_fgraph,
inner_node, inner_node,
eager=True, eager=True,
...@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve( ...@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node): def reuse_decomposition_multiple_solves(fgraph, node):
return _split_lu_solve_steps( return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"} fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
) )
@node_rewriter([Scan]) @node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve( return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen", "tridiagonal"} fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_split_non_sequence_lu_decomposition_solve", scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True), in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
...@@ -231,28 +252,30 @@ scan_seqopt1.register( ...@@ -231,28 +252,30 @@ scan_seqopt1.register(
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node): def reuse_decomposition_multiple_solves_jax(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
)
optdb["specialize"].register( optdb["specialize"].register(
reuse_lu_decomposition_multiple_solves_jax.__name__, reuse_decomposition_multiple_solves_jax.__name__,
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True), in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
) )
@node_rewriter([Scan]) @node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node): def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve( return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen"} fgraph, node, allowed_assume_a={"gen", "pos"}
) )
scan_seqopt1.register( scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__, scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True), in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
position=2, position=2,
......
...@@ -6,8 +6,8 @@ from pytensor.compile.mode import get_default_mode ...@@ -6,8 +6,8 @@ from pytensor.compile.mode import get_default_mode
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.tensor._linalg.solve.rewriting import ( from pytensor.tensor._linalg.solve.rewriting import (
reuse_lu_decomposition_multiple_solves, reuse_decomposition_multiple_solves,
scan_split_non_sequence_lu_decomposition_solve, scan_split_non_sequence_decomposition_and_solve,
) )
from pytensor.tensor._linalg.solve.tridiagonal import ( from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal, LUFactorTridiagonal,
...@@ -15,62 +15,70 @@ from pytensor.tensor._linalg.solve.tridiagonal import ( ...@@ -15,62 +15,70 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
) )
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg import solve from pytensor.tensor.linalg import solve
from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular from pytensor.tensor.slinalg import (
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveTriangular,
)
from pytensor.tensor.type import tensor from pytensor.tensor.type import tensor
def count_vanilla_solve_nodes(nodes) -> int: class DecompSolveOpCounter:
return sum( def __init__(self, solve_op, decomp_op, solve_op_value: float = 1.0):
( self.solve_op = solve_op
isinstance(node.op, Solve) self.decomp_op = decomp_op
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)) self.solve_op_value = solve_op_value
def check_node_op_or_core_op(self, node, op):
return isinstance(node.op, op) or (
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, op)
) )
for node in nodes
)
def count_vanilla_solve_nodes(self, nodes) -> int:
return sum(self.check_node_op_or_core_op(node, Solve) for node in nodes)
def count_lu_decom_nodes(nodes) -> int: def count_decomp_nodes(self, nodes) -> int:
return sum( return sum(
( self.check_node_op_or_core_op(node, self.decomp_op) for node in nodes
isinstance(node.op, LUFactor | LUFactorTridiagonal)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, LUFactor | LUFactorTridiagonal)
)
) )
for node in nodes
)
def count_lu_solve_nodes(nodes) -> int: def count_solve_nodes(self, nodes) -> int:
count = sum( count = sum(
( self.solve_op_value * self.check_node_op_or_core_op(node, self.solve_op)
# LUFactor uses 2 SolveTriangular nodes, so we count each as 0.5 for node in nodes
0.5
* (
isinstance(node.op, SolveTriangular)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, SolveTriangular)
)
)
or (
isinstance(node.op, SolveLUFactorTridiagonal)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, SolveLUFactorTridiagonal)
)
)
) )
for node in nodes return int(count)
)
return int(count)
LUOpCounter = DecompSolveOpCounter(
solve_op=SolveTriangular,
decomp_op=LUFactor,
# Each rewrite introduces two triangular solves, so count them as 1/2 each
solve_op_value=0.5,
)
TriDiagLUOpCounter = DecompSolveOpCounter(
solve_op=SolveLUFactorTridiagonal, decomp_op=LUFactorTridiagonal, solve_op_value=1.0
)
CholeskyOpCounter = DecompSolveOpCounter(
solve_op=CholeskySolve, decomp_op=Cholesky, solve_op_value=1.0
)
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) @pytest.mark.parametrize(
def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): "assume_a, counter",
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ (
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transposed):
rewrite_name = reuse_decomposition_multiple_solves.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(3, 3)) A = tensor("A", shape=(3, 3))
...@@ -80,19 +88,22 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): ...@@ -80,19 +88,22 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
grad_x_wrt_A = grad(x.sum(), A) grad_x_wrt_A = grad(x.sum(), A)
fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 2 assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2
assert count_lu_decom_nodes(no_opt_nodes) == 0 assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0 assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name))
opt_nodes = fn_opt.maker.fgraph.apply_nodes opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1 assert counter.count_decomp_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2 assert counter.count_solve_nodes(opt_nodes) == 2
# Make sure results are correct # Make sure results are correct
rng = np.random.default_rng(31) rng = np.random.default_rng(31)
A_test = rng.random(A.type.shape, dtype=A.type.dtype) A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
b_test = rng.random(b.type.shape, dtype=b.type.dtype) b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0, resg0 = fn_no_opt(A_test, b_test) resx0, resg0 = fn_no_opt(A_test, b_test)
resx1, resg1 = fn_opt(A_test, b_test) resx1, resg1 = fn_opt(A_test, b_test)
...@@ -102,9 +113,16 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): ...@@ -102,9 +113,16 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) @pytest.mark.parametrize(
def test_lu_decomposition_reused_blockwise(assume_a, transposed): "assume_a, counter",
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ (
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed):
rewrite_name = reuse_decomposition_multiple_solves.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(3, 3)) A = tensor("A", shape=(3, 3))
...@@ -113,30 +131,40 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed): ...@@ -113,30 +131,40 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed):
x = solve(A, b, assume_a=assume_a, transposed=transposed) x = solve(A, b, assume_a=assume_a, transposed=transposed)
fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 1 assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1
assert count_lu_decom_nodes(no_opt_nodes) == 0 assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0 assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) fn_opt = function([A, b], [x], mode=mode.including(rewrite_name))
opt_nodes = fn_opt.maker.fgraph.apply_nodes opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1 assert counter.count_decomp_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 1 assert counter.count_solve_nodes(opt_nodes) == 1
# Make sure results are correct # Make sure results are correct
rng = np.random.default_rng(31) rng = np.random.default_rng(31)
A_test = rng.random(A.type.shape, dtype=A.type.dtype) A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
b_test = rng.random(b.type.shape, dtype=b.type.dtype) b_test = rng.random(b.type.shape, dtype=b.type.dtype)
resx0 = fn_no_opt(A_test, b_test) resx0 = fn_no_opt(A_test, b_test)
resx1 = fn_opt(A_test, b_test) resx1 = fn_opt(A_test, b_test)
rtol = rtol = 1e-7 if config.floatX == "float64" else 1e-4 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("transposed", (False, True))
@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) @pytest.mark.parametrize(
def test_lu_decomposition_reused_scan(assume_a, transposed): "assume_a, counter",
rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__ (
("gen", LUOpCounter),
("tridiagonal", TriDiagLUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__
mode = get_default_mode() mode = get_default_mode()
A = tensor("A", shape=(3, 3)) A = tensor("A", shape=(3, 3))
...@@ -158,23 +186,26 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): ...@@ -158,23 +186,26 @@ def test_lu_decomposition_reused_scan(assume_a, transposed):
node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
] ]
no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes
assert count_vanilla_solve_nodes(no_opt_nodes) == 1 assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1
assert count_lu_decom_nodes(no_opt_nodes) == 0 assert counter.count_decomp_nodes(no_opt_nodes) == 0
assert count_lu_solve_nodes(no_opt_nodes) == 0 assert counter.count_solve_nodes(no_opt_nodes) == 0
fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name))
[opt_scan_node] = [ [opt_scan_node] = [
node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
] ]
opt_nodes = opt_scan_node.op.fgraph.apply_nodes opt_nodes = opt_scan_node.op.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
# The LU decomp is outside of the scan! # The LU decomp is outside of the scan!
assert count_lu_decom_nodes(opt_nodes) == 0 assert counter.count_decomp_nodes(opt_nodes) == 0
assert count_lu_solve_nodes(opt_nodes) == 1 assert counter.count_solve_nodes(opt_nodes) == 1
# Make sure results are correct # Make sure results are correct
rng = np.random.default_rng(170) rng = np.random.default_rng(170)
A_test = rng.random(A.type.shape, dtype=A.type.dtype) A_test = rng.random(A.type.shape, dtype=A.type.dtype)
if assume_a == "pos":
A_test = A_test @ A_test.T # Ensure positive definite for Cholesky
x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype)
resx0 = fn_no_opt(A_test, x0_test) resx0 = fn_no_opt(A_test, x0_test)
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
...@@ -182,23 +213,30 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): ...@@ -182,23 +213,30 @@ def test_lu_decomposition_reused_scan(assume_a, transposed):
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
def test_lu_decomposition_reused_preserves_check_finite(): @pytest.mark.parametrize(
"assume_a, counter",
(
("gen", LUOpCounter),
("pos", CholeskyOpCounter),
),
)
def test_decomposition_reused_preserves_check_finite(assume_a, counter):
# Check that the LU decomposition rewrite preserves the check_finite flag # Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ rewrite_name = reuse_decomposition_multiple_solves.__name__
A = tensor("A", shape=(2, 2)) A = tensor("A", shape=(2, 2))
b1 = tensor("b1", shape=(2,)) b1 = tensor("b1", shape=(2,))
b2 = tensor("b2", shape=(2,)) b2 = tensor("b2", shape=(2,))
x1 = solve(A, b1, assume_a="gen", check_finite=True) x1 = solve(A, b1, assume_a=assume_a, check_finite=True)
x2 = solve(A, b2, assume_a="gen", check_finite=False) x2 = solve(A, b2, assume_a=assume_a, check_finite=False)
fn_opt = function( fn_opt = function(
[A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name) [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
) )
opt_nodes = fn_opt.maker.fgraph.apply_nodes opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1 assert counter.count_decomp_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2 assert counter.count_solve_nodes(opt_nodes) == 2
# We should get an error if A or b1 is non finite # We should get an error if A or b1 is non finite
A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype) A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
......
...@@ -581,7 +581,7 @@ class TestInplace: ...@@ -581,7 +581,7 @@ class TestInplace:
mode = get_default_mode().excluding( mode = get_default_mode().excluding(
"batched_vector_b_solve_to_matrix_b_solve", "batched_vector_b_solve_to_matrix_b_solve",
"reuse_lu_decomposition_multiple_solves", "reuse_decomposition_multiple_solves",
) )
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode) fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
......
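Markdown is supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment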