提交 e98cbbcf authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: Jesse Grabowski

Numba dispatch for LU ops

上级 679b2f71
......@@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
......
from collections.abc import Callable
from typing import cast as typing_cast
import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_njit
def _pivot_to_permutation(p, dtype):
    """
    Convert a 0-based pivot vector `p` (row-swap indices, as produced by
    getrf after the 1-based adjustment) into an inverse permutation vector.
    """
    p_inv = np.arange(len(p)).astype(dtype)
    # The swaps must be replayed sequentially: each step may touch entries
    # already moved by an earlier swap, so this loop cannot be vectorized.
    for i in range(len(p)):
        p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
    return p_inv
@numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
    """
    Expand the packed getrf factorization of `a` into explicit (perm, L, U),
    where L has a unit diagonal and perm is a row-permutation vector such
    that L[perm] @ U reconstructs `a`.
    """
    # INFO from getrf is discarded here; error checking is left to callers.
    A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

    # The packed LU stores L (without its implicit unit diagonal) below the
    # diagonal and U on/above it.
    L = np.eye(A_copy.shape[-1], dtype=dtype)
    L += np.tril(A_copy, k=-1)
    U = np.triu(A_copy)

    # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
    IPIV = IPIV - 1
    p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
    # argsort of the inverse permutation yields the forward permutation.
    perm = np.argsort(p_inv)

    return perm, L, U
def _lu_1(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_2(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_3(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when p_indices is
    True. Returns a tuple of (perm, L, U), where perm is an integer array of row swaps, such that L[perm] @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # dtype is resolved at overload (compile) time and captured by `impl`.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # permute_l / check_finite / p_indices only exist to match the wrapper's
        # signature; the dispatcher already selected this code path.
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        return perm, L, U

    return impl
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
    False. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # dtype is resolved at overload (compile) time and captured by `impl`.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Applying the row permutation to L folds P into it: PL = P @ L = L[perm].
        PL = L[perm]

        return PL, U

    return impl
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
    False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # dtype is resolved at overload (compile) time and captured by `impl`.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Materialize the permutation vector as a dense permutation matrix.
        P = np.eye(a.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
return A_copy, ipiv, info
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """
    Numba overload of _getrf, calling LAPACK xgetrf directly.

    Returns (LU, IPIV, INFO). Note that IPIV is the raw LAPACK pivot vector,
    which is 1-based (Fortran indexing); callers that need 0-based pivots must
    subtract 1 themselves (see lu_factor_impl / _lu_factor_to_lu).
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    # Underlying real scalar type used to reinterpret the buffer for LAPACK.
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        # LAPACK requires Fortran (column-major) storage; reuse A only when the
        # caller allows overwriting and A is already F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        return A_copy, IPIV, int_ptr_to_val(INFO)

    return impl
def _lu_factor(A: np.ndarray, overwrite_a: bool = False):
"""
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
Pytensor.
"""
return linalg.lu_factor(A, overwrite_a=overwrite_a)
@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload of _lu_factor: packed LU factorization via LAPACK getrf,
    returning (LU, piv) with 0-based pivot indices.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")

    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
        IPIV -= 1  # LAPACK uses 1-based indexing, convert to 0-based

        # Any nonzero LAPACK status is treated as a failure here.
        # NOTE(review): scipy.linalg.lu_factor only warns on a singular matrix
        # (INFO > 0) — confirm this stricter behavior is intended.
        if INFO != 0:
            raise np.linalg.LinAlgError("LU decomposition failed")

        return A_copy, IPIV

    return impl
......@@ -11,13 +11,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
......@@ -72,116 +72,6 @@ def xgecon_impl(
return impl
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Placeholder for LU factorization; used by linalg.solve.

    # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
    """
    # Intentionally returns None: this pure-Python body is never meant to run;
    # it only exists as a target for the @overload(_getrf) registration below.
    return  # type: ignore
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """
    Numba overload of _getrf, calling LAPACK xgetrf directly.

    Returns (LU, IPIV, INFO), where IPIV is the raw (1-based) LAPACK pivot
    vector and INFO the LAPACK status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    # Underlying real scalar type used to reinterpret the buffer for LAPACK.
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        # LAPACK requires Fortran (column-major) storage; reuse A only when the
        # caller allows overwriting and A is already F-contiguous.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        return A_copy, IPIV, int_ptr_to_val(INFO)

    return impl
def _getrs(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
    """
    Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.

    # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
    """
    # Intentionally returns None: this pure-Python body is never meant to run;
    # it only exists as a target for the @overload(_getrs) registration below.
    return  # type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
    """
    Numba overload of _getrs, calling LAPACK xgetrs to solve op(A) @ X = B
    from the packed LU factorization of A.

    Returns (X, INFO) where INFO is the LAPACK status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    # Underlying real scalar type used to reinterpret buffers for LAPACK.
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)

    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)

        B_is_1d = B.ndim == 1

        # LAPACK needs Fortran-ordered B; reuse the caller's buffer only when
        # overwriting is allowed and B is already F-contiguous.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order_even_if_1d(B)

        if B_is_1d:
            # Promote vectors to a single-column matrix for the LAPACK call.
            B_copy = np.expand_dims(B_copy, -1)

        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)  # rebind: plain int -> int pointer
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        IPIV = _copy_to_fortran_order(IPIV)
        INFO = val_to_int_ptr(0)

        # NOTE(review): LU is handed to LAPACK without a Fortran-order copy —
        # this assumes callers always supply an F-contiguous LU; confirm.
        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        if B_is_1d:
            # Strip the synthetic column axis added above.
            B_copy = B_copy[..., 0]
        return B_copy, int_ptr_to_val(INFO)

    return impl
def _solve_gen(
A: np.ndarray,
B: np.ndarray,
......
from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
)
def _getrs(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> tuple[np.ndarray, int]:
    """
    Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
    """
    # Intentionally returns None: this pure-Python body is never meant to run;
    # it only exists as a target for the @overload(_getrs) registration below.
    return  # type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
    """
    Numba overload of _getrs, calling LAPACK xgetrs to solve op(A) @ X = B
    from the packed LU factorization of A.

    Returns (X, INFO) where INFO is the LAPACK status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    # Underlying real scalar type used to reinterpret buffers for LAPACK.
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)

    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)

        B_is_1d = B.ndim == 1

        # LAPACK needs Fortran-ordered B; reuse the caller's buffer only when
        # overwriting is allowed and B is already F-contiguous.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order_even_if_1d(B)

        if B_is_1d:
            # Promote vectors to a single-column matrix for the LAPACK call.
            B_copy = np.expand_dims(B_copy, -1)

        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)  # rebind: plain int -> int pointer
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        IPIV = _copy_to_fortran_order(IPIV)
        INFO = val_to_int_ptr(0)

        # NOTE(review): LU is handed to LAPACK without a Fortran-order copy —
        # this assumes callers always supply an F-contiguous LU; confirm.
        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        if B_is_1d:
            # Strip the synthetic column axis added above.
            B_copy = B_copy[..., 0]
        return B_copy, int_ptr_to_val(INFO)

    return impl
def _lu_solve(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
overwrite_b: bool,
check_finite: bool,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return linalg.lu_solve(
lu_and_piv, b, trans=trans, overwrite_b=overwrite_b, check_finite=check_finite
)
@overload(_lu_solve)
def lu_solve_impl(
    lu_and_piv: tuple[np.ndarray, np.ndarray],
    b: np.ndarray,
    trans: int,
    overwrite_b: bool,
    check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
    """
    Numba overload of _lu_solve: solve A @ X = B from a (lu, piv)
    factorization via LAPACK getrs.

    NOTE(review): the overloaded wrapper takes 5 arguments with the
    factorization packed as a tuple, while `impl` takes 6 with `lu` and `piv`
    unpacked — numba requires the implementation signature to match the call
    site, so confirm how _lu_solve is invoked from compiled graphs.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
    _check_scipy_linalg_matrix(b, "lu_solve")

    def impl(
        lu: np.ndarray,
        piv: np.ndarray,
        b: np.ndarray,
        trans: int,
        overwrite_b: bool,
        check_finite: bool,
    ) -> np.ndarray:
        n = np.int32(lu.shape[0])

        X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
        # Validate the LAPACK status code before returning the solution.
        _solve_check(n, INFO)

        return X

    return impl
......@@ -4,6 +4,13 @@ import numpy as np
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_lu_1,
_lu_2,
_lu_3,
_pivot_to_permutation,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
......@@ -11,9 +18,12 @@ from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.tensor.slinalg import (
LU,
BlockDiagonal,
Cholesky,
CholeskySolve,
LUFactor,
PivotToPermutations,
Solve,
SolveTriangular,
)
......@@ -70,6 +80,96 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return cholesky
@numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
    """
    Numba dispatch for the PivotToPermutations Op.

    Converts a pivot vector into the inverse permutation when op.inverse is
    True, otherwise into the forward permutation via argsort.
    """
    # Read Op/node properties at funcify time so the jitted closure only
    # captures plain Python values.
    inverse = op.inverse
    dtype = node.inputs[0].dtype

    @numba_njit
    def numba_pivot_to_permutation(piv):
        p_inv = _pivot_to_permutation(piv, dtype)

        if inverse:
            return p_inv

        return np.argsort(p_inv)

    return numba_pivot_to_permutation
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """
    Numba dispatch for the LU Op.

    Selects one of three overload targets based on the Op's flags:
    _lu_1 when p_indices is True (returns perm, L, U), _lu_2 when permute_l is
    True (returns PL, U), and _lu_3 otherwise (returns P, L, U).
    """
    # Read Op/node properties at funcify time so the jitted closure only
    # captures plain Python values.
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a
    dtype = node.inputs[0].dtype

    if dtype in complex_dtypes:
        # Bug fix: the exception was previously constructed but never raised,
        # silently allowing unsupported complex inputs through.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit(inline="always")
    def lu(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        if p_indices:
            res = _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        elif permute_l:
            res = _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        else:
            res = _lu_3(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        return res

    return lu
@numba_funcify.register(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
    """
    Numba dispatch for the LUFactor Op, backed by the LAPACK getrf overload.
    """
    # Read Op/node properties at funcify time so the jitted closure only
    # captures plain Python values.
    dtype = node.inputs[0].dtype
    check_finite = op.check_finite
    overwrite_a = op.overwrite_a

    if dtype in complex_dtypes:
        # Bug fix: the exception was previously constructed but never raised,
        # silently allowing unsupported complex inputs through.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit
    def lu_factor(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    # Bug fix: message previously said "cholesky" (copy-paste).
                    "Non-numeric values (nan or inf) found in input to lu_factor"
                )

        LU, piv = _lu_factor(a, overwrite_a)

        return LU, piv

    return lu_factor
@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
......
......@@ -8,7 +8,14 @@ import scipy
import pytensor
import pytensor.tensor as pt
from pytensor import In, config
from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangular
from pytensor.tensor.slinalg import (
LU,
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveTriangular,
)
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
......@@ -494,3 +501,222 @@ def test_block_diag():
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
    """Check the numba PivotToPermutations dispatch against scipy's LU pivots."""
    from pytensor.tensor.slinalg import pivot_to_permutation

    rng = np.random.default_rng(123)
    A = rng.normal(size=(5, 5)).astype(floatX)

    perm_pt = pt.vector("p", dtype="int32")
    piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
    f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")

    _, piv = scipy.linalg.lu_factor(A)

    if inverse:
        # Reference: replay the row swaps encoded in piv sequentially,
        # mirroring the swap loop in _pivot_to_permutation.
        p = np.arange(len(piv))
        for i in range(len(piv)):
            p[i], p[piv[i]] = p[piv[i]], p[i]
        np.testing.assert_allclose(f(piv), p)
    else:
        # Reference: scipy returns the forward permutation indices directly.
        p, *_ = scipy.linalg.lu(A, p_indices=True)
        np.testing.assert_allclose(f(piv), p)
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu(permute_l, p_indices, overwrite_a):
    """Check the numba LU dispatch for every output mode and overwrite/contiguity combination."""
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        lu_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LU)

    destroy_map = op.destroy_map

    # With permute_l the input is destroyed through output 0; otherwise through
    # output 1; with no overwrite, nothing is destroyed.
    if overwrite_a and permute_l:
        assert destroy_map == {0: [0]}
    elif overwrite_a:
        assert destroy_map == {1: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(overwrite_a):
    """Check the numba LUFactor dispatch, including overwrite and contiguity behavior."""
    shape = (5, 5)
    rng = np.random.default_rng()

    A = pt.tensor("A", shape=shape, dtype=config.floatX)
    A_val = rng.normal(size=shape).astype(config.floatX)

    LU, piv = pt.linalg.lu_factor(A)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        [LU, piv],
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LUFactor)

    # The input can only be destroyed through output 1 (the packed LU).
    if overwrite_a:
        assert op.destroy_map == {1: [0]}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
    "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
    "b_func, b_shape",
    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
    ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
    """Check the numba lu_solve dispatch, including overwrite and contiguity behavior of b."""
    A = pt.matrix("A", dtype=floatX)
    b = pt.tensor("b", shape=b_shape, dtype=floatX)

    rng = np.random.default_rng(418)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    b_val = rng.normal(size=b_shape).astype(floatX)

    lu_and_piv = pt.linalg.lu_factor(A)
    X = pt.linalg.lu_solve(
        lu_and_piv,
        b,
        b_ndim=len(b_shape),
        trans=trans,
    )

    f, res = compare_numba_and_py(
        [A, In(b, mutable=overwrite_b)],
        X,
        test_inputs=[A_val, b_val],
        inplace=True,
        numba_mode=numba_inplace_mode,
        eval_obj_mode=False,
    )

    # Test with F_contiguous inputs
    A_val_f_contig = np.copy(A_val, order="F")
    b_val_f_contig = np.copy(b_val, order="F")
    res_f_contig = f(A_val_f_contig, b_val_f_contig)
    np.testing.assert_allclose(res_f_contig, res)

    all_equal = (b_val == b_val_f_contig).all()
    # b is only destroyed on the transposed path when overwriting is allowed.
    should_destroy = overwrite_b and trans

    if should_destroy:
        assert not all_equal
    else:
        assert all_equal

    # Test with C_contiguous inputs
    A_val_c_contig = np.copy(A_val, order="C")
    b_val_c_contig = np.copy(b_val, order="C")
    res_c_contig = f(A_val_c_contig, b_val_c_contig)

    np.testing.assert_allclose(res_c_contig, res)
    np.testing.assert_allclose(A_val_c_contig, A_val)

    # b c_contiguous vectors are also f_contiguous and destroyable
    assert not (should_destroy and b_val_c_contig.flags.f_contiguous) == np.allclose(
        b_val_c_contig, b_val
    )

    # Test with non-contiguous inputs
    A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
    res_not_contig = f(A_val_not_contig, b_val_not_contig)
    np.testing.assert_allclose(res_not_contig, res)
    np.testing.assert_allclose(A_val_not_contig, A_val)

    # Can never destroy non-contiguous inputs
    np.testing.assert_allclose(b_val_not_contig, b_val)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论