提交 0fd8315f 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix contiguity bugs in Numba lapack routines

Also removes redundant tests
上级 a149f6c9
......@@ -26,6 +26,12 @@ from pytensor.tensor.slinalg import (
)
@numba_basic.numba_njit(inline="always")
def _copy_to_fortran_order_even_if_1d(x):
    """Return a copy of ``x``, copying explicitly when ``x`` is a vector.

    Numba's ``_copy_to_fortran_order`` doesn't do anything for vectors, so
    1-D input is duplicated with ``x.copy()``; any other input is delegated
    to ``_copy_to_fortran_order``.
    """
    if x.ndim == 1:
        return x.copy()
    return _copy_to_fortran_order(x)
@numba_basic.numba_njit(inline="always")
def _solve_check(n, info, lamch=False, rcond=None):
"""
......@@ -132,18 +138,13 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
# This will only copy if A is not already fortran contiguous
A_f = np.asfortranarray(A)
if overwrite_b:
if B_is_1d:
B_copy = np.expand_dims(B, -1)
else:
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
B_copy = np.asfortranarray(B)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.copy(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
......@@ -247,10 +248,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
numba_potrf(
UPLO,
......@@ -283,7 +284,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
"""
lower = op.lower
overwrite_a = False
overwrite_a = op.overwrite_a
check_finite = op.check_finite
on_error = op.on_error
......@@ -497,10 +498,10 @@ def getrf_impl(
) -> tuple[np.ndarray, np.ndarray, int]:
_M, _N = np.int32(A.shape[-2:]) # type: ignore
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
M = val_to_int_ptr(_M) # type: ignore
N = val_to_int_ptr(_N) # type: ignore
......@@ -545,10 +546,10 @@ def getrs_impl(
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
......@@ -576,7 +577,7 @@ def getrs_impl(
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
......@@ -681,19 +682,20 @@ def sysv_impl(
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
......@@ -903,17 +905,17 @@ def posv_impl(
_N = np.int32(A.shape[-1])
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
B_is_1d = B.ndim == 1
if not overwrite_b:
B_copy = _copy_to_fortran_order(B)
else:
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
......@@ -1102,12 +1104,15 @@ def numba_funcify_Solve(op, node, **kwargs):
return solve
def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True):
def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
A, lower = A_and_lower
return linalg.cho_solve((A, lower), B)
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_cho_solve)
......@@ -1123,13 +1128,16 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
C_copy = _copy_to_fortran_order(C)
C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous:
B_copy = B
else:
B_copy = _copy_to_fortran_order_even_if_1d(B)
B_is_1d = B.ndim == 1
if B_is_1d:
B_copy = np.asfortranarray(np.expand_dims(B, -1))
else:
B_copy = _copy_to_fortran_order(B)
B_copy = np.expand_dims(B_copy, -1)
NRHS = 1 if B_is_1d else int(B.shape[-1])
......@@ -1144,16 +1152,18 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO,
N,
NRHS,
C_copy.view(w_type).ctypes,
C_f.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)
_solve_check(_N, int_ptr_to_val(INFO))
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return B_copy[..., 0]
return B_copy
return impl
......@@ -1182,16 +1192,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to cho_solve"
)
res, info = _cho_solve(
return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
)
if info < 0:
raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
elif info > 0:
raise np.linalg.LinAlgError(
"Matrix is not positive definite in input to cho_solve"
)
return res
return cho_solve
......@@ -7,6 +7,7 @@ from unittest import mock
import numpy as np
import pytest
from pytensor.compile import SymbolicInput
from tests.tensor.test_math_scipy import scipy
......@@ -120,6 +121,7 @@ opts = RewriteDatabaseQuery(
numba_mode = Mode(
NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")
)
numba_inplace_mode = numba_mode.including("inplace")
py_mode = Mode("py", opts)
rng = np.random.default_rng(42849)
......@@ -261,7 +263,11 @@ def compare_numba_and_py(
x, y
)
if any(inp.owner is not None for inp in graph_inputs):
if any(
inp.owner is not None
for inp in graph_inputs
if not isinstance(inp, SymbolicInput)
):
raise ValueError("Inputs must be root variables")
pytensor_py_fn = function(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论