提交 148477cb authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Respect check_finite in LU decomposition rewrites

上级 6fb515d0
...@@ -14,16 +14,22 @@ from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve ...@@ -14,16 +14,22 @@ from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def decompose_A(A, assume_a): def decompose_A(A, assume_a, check_finite):
if assume_a == "gen": if assume_a == "gen":
return lu_factor(A, check_finite=False) return lu_factor(A, check_finite=check_finite)
else: else:
raise NotImplementedError raise NotImplementedError
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False): def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
if assume_a == "gen": if core_solve_op.assume_a == "gen":
return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed) return lu_solve(
A_decomp,
b,
trans=transposed,
b_ndim=core_solve_op.b_ndim,
check_finite=core_solve_op.check_finite,
)
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -102,14 +108,19 @@ def _split_lu_solve_steps( ...@@ -102,14 +108,19 @@ def _split_lu_solve_steps(
): ):
return None return None
A_decomp = decompose_A(A, assume_a=assume_a) # If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp = False
for client, _ in A_solve_clients_and_transpose:
if client.op.core_op.check_finite:
check_finite_decomp = True
break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
replacements = {} replacements = {}
for client, transposed in A_solve_clients_and_transpose: for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs _, b = client.inputs
b_ndim = client.op.core_op.b_ndim
new_x = solve_lu_decomposed_system( new_x = solve_lu_decomposed_system(
A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
) )
[old_x] = client.outputs [old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype) new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
......
...@@ -793,7 +793,7 @@ def tensor( ...@@ -793,7 +793,7 @@ def tensor(
try: try:
# Help catching errors with the new tensor API # Help catching errors with the new tensor API
# Many single letter strings are valid sctypes # Many single letter strings are valid sctypes
if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): if str(name) == "floatX" or (len(str(name)) > 2 and np.dtype(name).type):
raise ValueError( raise ValueError(
f"The first and only positional argument of tensor is now `name`. Got {name}.\n" f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
"This name looks like a dtype, which you should pass as a keyword argument only." "This name looks like a dtype, which you should pass as a keyword argument only."
......
...@@ -161,3 +161,36 @@ def test_lu_decomposition_reused_scan(transposed): ...@@ -161,3 +161,36 @@ def test_lu_decomposition_reused_scan(transposed):
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-6 rtol = 1e-7 if config.floatX == "float64" else 1e-6
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
def test_lu_decomposition_reused_preserves_check_finite():
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
A = tensor("A", shape=(2, 2))
b1 = tensor("b1", shape=(2,))
b2 = tensor("b2", shape=(2,))
x1 = solve(A, b1, assume_a="gen", check_finite=True)
x2 = solve(A, b2, assume_a="gen", check_finite=False)
fn_opt = function(
[A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
)
opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2
# We should get an error if A or b1 is non finite
A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
b1_valid = np.array([1, 1], dtype=b1.type.dtype)
b2_valid = np.array([1, 1], dtype=b2.type.dtype)
assert fn_opt(A_valid, b1_valid, b2_valid) # Fine
assert fn_opt(
A_valid, b1_valid, b2_valid * np.nan
) # Should not raise (also fine on most LAPACK implementations?)
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论