提交 0fd8315f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix contiguity bugs in Numba lapack routines

Also removes redundant tests
上级 a149f6c9
...@@ -26,6 +26,12 @@ from pytensor.tensor.slinalg import ( ...@@ -26,6 +26,12 @@ from pytensor.tensor.slinalg import (
) )
@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None): def _solve_check(n, info, lamch=False, rcond=None):
""" """
...@@ -132,18 +138,13 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -132,18 +138,13 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
# This will only copy if A is not already fortran contiguous # This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A) A_f = np.asfortranarray(A)
if overwrite_b: if overwrite_b and B.flags.f_contiguous:
if B_is_1d: B_copy = B
B_copy = np.expand_dims(B, -1)
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)
else: else:
if B_is_1d: B_copy = _copy_to_fortran_order_even_if_1d(B)
B_copy = np.copy(np.expand_dims(B, -1))
else: if B_is_1d:
B_copy = _copy_to_fortran_order(B) B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
...@@ -247,10 +248,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -247,10 +248,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
LDA = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0) INFO = val_to_int_ptr(0)
if not overwrite_a: if overwrite_a and A.flags.f_contiguous:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
numba_potrf( numba_potrf(
UPLO, UPLO,
...@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version. In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
""" """
lower = op.lower lower = op.lower
overwrite_a = False overwrite_a = op.overwrite_a
check_finite = op.check_finite check_finite = op.check_finite
on_error = op.on_error on_error = op.on_error
...@@ -497,10 +498,10 @@ def getrf_impl( ...@@ -497,10 +498,10 @@ def getrf_impl(
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore _M, _N = np.int32(A.shape[-2:]) # type: ignore
if not overwrite_a: if overwrite_a and A.flags.f_contiguous:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore N = val_to_int_ptr(_N) # type: ignore
...@@ -545,10 +546,10 @@ def getrs_impl( ...@@ -545,10 +546,10 @@ def getrs_impl(
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if not overwrite_b: if overwrite_b and B.flags.f_contiguous:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d: if B_is_1d:
B_copy = np.expand_dims(B_copy, -1) B_copy = np.expand_dims(B_copy, -1)
...@@ -576,7 +577,7 @@ def getrs_impl( ...@@ -576,7 +577,7 @@ def getrs_impl(
) )
if B_is_1d: if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO) B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO) return B_copy, int_ptr_to_val(INFO)
...@@ -681,19 +682,20 @@ def sysv_impl( ...@@ -681,19 +682,20 @@ def sysv_impl(
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore _LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
if not overwrite_a: if overwrite_a and A.flags.f_contiguous:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if not overwrite_b: if overwrite_b and B.flags.f_contiguous:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d: if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1)) B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1]) NRHS = 1 if B_is_1d else int(B.shape[-1])
...@@ -903,17 +905,17 @@ def posv_impl( ...@@ -903,17 +905,17 @@ def posv_impl(
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if not overwrite_a: if overwrite_a and A.flags.f_contiguous:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if not overwrite_b: if overwrite_b and B.flags.f_contiguous:
B_copy = _copy_to_fortran_order(B)
else:
B_copy = B B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d: if B_is_1d:
B_copy = np.expand_dims(B_copy, -1) B_copy = np.expand_dims(B_copy, -1)
...@@ -1102,12 +1104,15 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1102,12 +1104,15 @@ def numba_funcify_Solve(op, node, **kwargs):
return solve return solve
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True): def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
""" """
Solve a positive-definite linear system using the Cholesky decomposition. Solve a positive-definite linear system using the Cholesky decomposition.
""" """
A, lower = A_and_lower return linalg.cho_solve(
return linalg.cho_solve((A, lower), B) (C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_cho_solve) @overload(_cho_solve)
...@@ -1123,13 +1128,16 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -1123,13 +1128,16 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B) _solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1]) _N = np.int32(C.shape[-1])
C_copy = _copy_to_fortran_order(C) C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
if B_is_1d: if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1)) B_copy = np.expand_dims(B_copy, -1)
else:
B_copy = _copy_to_fortran_order(B)
NRHS = 1 if B_is_1d else int(B.shape[-1]) NRHS = 1 if B_is_1d else int(B.shape[-1])
...@@ -1144,16 +1152,18 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -1144,16 +1152,18 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO, UPLO,
N, N,
NRHS, NRHS,
C_copy.view(w_type).ctypes, C_f.view(w_type).ctypes,
LDA, LDA,
B_copy.view(w_type).ctypes, B_copy.view(w_type).ctypes,
LDB, LDB,
INFO, INFO,
) )
_solve_check(_N, int_ptr_to_val(INFO))
if B_is_1d: if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO) return B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO) return B_copy
return impl return impl
...@@ -1182,16 +1192,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -1182,16 +1192,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to cho_solve" "Non-numeric values (nan or inf) in input b to cho_solve"
) )
res, info = _cho_solve( return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
) )
if info < 0:
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
elif info > 0:
raise np.linalg.LinAlgError(
"Matrix is not positive definite in input to cho_solve"
)
return res
return cho_solve return cho_solve
...@@ -7,6 +7,7 @@ from unittest import mock ...@@ -7,6 +7,7 @@ from unittest import mock
import numpy as np import numpy as np
import pytest import pytest
from pytensor.compile import SymbolicInput
from tests.tensor.test_math_scipy import scipy from tests.tensor.test_math_scipy import scipy
...@@ -120,6 +121,7 @@ opts = RewriteDatabaseQuery( ...@@ -120,6 +121,7 @@ opts = RewriteDatabaseQuery(
numba_mode = Mode( numba_mode = Mode(
NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise") NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")
) )
numba_inplace_mode = numba_mode.including("inplace")
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -261,7 +263,11 @@ def compare_numba_and_py( ...@@ -261,7 +263,11 @@ def compare_numba_and_py(
x, y x, y
) )
if any(inp.owner is not None for inp in graph_inputs): if any(
inp.owner is not None
for inp in graph_inputs
if not isinstance(inp, SymbolicInput)
):
raise ValueError("Inputs must be root variables") raise ValueError("Inputs must be root variables")
pytensor_py_fn = function( pytensor_py_fn = function(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论