提交 ffd999c8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba linalg: handle dtypes more strictly

上级 edb1b205
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
......@@ -24,9 +24,9 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True):
......@@ -47,7 +47,7 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
INFO,
)
......
......@@ -3,12 +3,13 @@ from typing import Literal
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
@numba_basic.numba_njit
......@@ -116,7 +117,7 @@ def lu_impl_1(
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype
def impl(
......@@ -146,7 +147,7 @@ def lu_impl_2(
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype
def impl(
......@@ -179,7 +180,7 @@ def lu_impl_3(
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype
def impl(
......
......@@ -3,18 +3,16 @@ from typing import cast as typing_cast
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
......@@ -38,9 +36,8 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)
def impl(
......@@ -59,7 +56,7 @@ def getrf_impl(
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)
return A_copy, IPIV, int_ptr_to_val(INFO)
......@@ -79,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "lu_factor")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
......
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
......@@ -31,10 +32,10 @@ def _cho_solve(
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve")
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
_check_dtypes_match((C, B), func_name="cho_solve")
dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
......@@ -71,9 +72,9 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO,
N,
NRHS,
C_f.view(w_type).ctypes,
C_f.ctypes,
LDA,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
......
......@@ -2,12 +2,12 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
......@@ -16,7 +16,8 @@ from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_solve_check,
)
......@@ -37,9 +38,8 @@ def xgecon_impl(
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
......@@ -58,11 +58,11 @@ def xgecon_impl(
numba_gecon(
NORM,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
A_NORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
A_NORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
......@@ -106,8 +106,9 @@ def solve_gen_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), "solve")
def impl(
A: np.ndarray,
......
......@@ -3,18 +3,19 @@ from typing import Literal, TypeAlias
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
......@@ -44,10 +45,11 @@ def getrs_impl(
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
_check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs")
_check_dtypes_match((LU, B), func_name="getrs")
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl(
......@@ -84,10 +86,10 @@ def getrs_impl(
TRANS,
N,
NRHS,
LU.view(w_type).ctypes,
LU.ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
......@@ -124,8 +126,10 @@ def lu_solve_impl(
check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
_check_scipy_linalg_matrix(b, "lu_solve")
lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve")
_check_dtypes_match((lu, b), func_name="lu_solve")
def impl(
lu: np.ndarray,
......
......@@ -2,14 +2,14 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float:
......@@ -28,9 +28,8 @@ def xlange_impl(
largest absolute value of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "norm")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_lange = _LAPACK().numba_xlange(dtype)
def impl(A: np.ndarray, order: str | None = None):
......@@ -49,9 +48,7 @@ def xlange_impl(
)
WORK = np.empty(_M, dtype=dtype) # type: ignore
result = numba_lange(
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
)
result = numba_lange(NORM, M, N, A_copy.ctypes, LDA, WORK.ctypes)
return result
......
......@@ -2,19 +2,20 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
......@@ -49,10 +50,10 @@ def posv_impl(
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_posv = _LAPACK().numba_xposv(dtype)
def impl(
......@@ -99,9 +100,9 @@ def posv_impl(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
......@@ -127,9 +128,8 @@ def pocon_impl(
A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "pocon")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="pocon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_pocon = _LAPACK().numba_xpocon(dtype)
def impl(A: np.ndarray, anorm: float):
......@@ -148,11 +148,11 @@ def pocon_impl(
numba_pocon(
UPLO,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
ANORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
......@@ -196,8 +196,9 @@ def solve_psd_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
A: np.ndarray,
......
......@@ -2,19 +2,20 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
......@@ -37,10 +38,10 @@ def sysv_impl(
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
_check_scipy_linalg_matrix(B, "sysv")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sysv")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="sysv")
_check_dtypes_match((A, B), func_name="sysv")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sysv = _LAPACK().numba_xsysv(dtype)
def impl(
......@@ -84,12 +85,12 @@ def sysv_impl(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
WORK.view(w_type).ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -103,12 +104,12 @@ def sysv_impl(
UPLO,
N,
NRHS,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
WORK.view(w_type).ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -133,9 +134,8 @@ def sycon_impl(
A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sycon")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sycon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sycon = _LAPACK().numba_xsycon(dtype)
def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
......@@ -154,12 +154,12 @@ def sycon_impl(
numba_sycon(
UPLO,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
ipiv.ctypes,
ANORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
ANORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
......@@ -203,8 +203,9 @@ def solve_symmetric_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
A: np.ndarray,
......
import numpy as np
from numba.core import types
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
......@@ -45,10 +46,10 @@ def _solve_triangular(
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve_triangular")
_check_dtypes_match((A, B), func_name="solve_triangular")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
......@@ -99,9 +100,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
DIAG,
N,
NRHS,
A_f.view(w_type).ctypes,
A_f.ctypes,
LDA,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
......
......@@ -2,21 +2,24 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.np.linalg import ensure_lapack
from numpy import ndarray
from scipy import linalg
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
......@@ -63,11 +66,11 @@ def gttrf_impl(
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrf")
_check_scipy_linalg_matrix(d, "gttrf")
_check_scipy_linalg_matrix(du, "gttrf")
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrf")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrf")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrf")
_check_dtypes_match((dl, d, du), func_name="gttrf")
dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gttrf = _LAPACK().numba_xgttrf(dtype)
def impl(
......@@ -94,10 +97,10 @@ def gttrf_impl(
numba_gttrf(
val_to_int_ptr(n),
dl.view(w_type).ctypes,
d.view(w_type).ctypes,
du.view(w_type).ctypes,
du2.view(w_type).ctypes,
dl.ctypes,
d.ctypes,
du.ctypes,
du2.ctypes,
ipiv.ctypes,
info,
)
......@@ -136,13 +139,14 @@ def gttrs_impl(
tuple[ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrs")
_check_scipy_linalg_matrix(d, "gttrs")
_check_scipy_linalg_matrix(du, "gttrs")
_check_scipy_linalg_matrix(du2, "gttrs")
_check_scipy_linalg_matrix(b, "gttrs")
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="gttrs")
_check_dtypes_match((dl, d, du, du2, b), func_name="gttrs")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gttrs")
dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gttrs = _LAPACK().numba_xgttrs(dtype)
def impl(
......@@ -181,12 +185,12 @@ def gttrs_impl(
val_to_int_ptr(_trans_char_to_int(trans)),
val_to_int_ptr(n),
val_to_int_ptr(nrhs),
dl.view(w_type).ctypes,
d.view(w_type).ctypes,
du.view(w_type).ctypes,
du2.view(w_type).ctypes,
dl.ctypes,
d.ctypes,
du.ctypes,
du2.ctypes,
ipiv.ctypes,
b.view(w_type).ctypes,
b.ctypes,
val_to_int_ptr(n),
info,
)
......@@ -222,12 +226,13 @@ def gtcon_impl(
[ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(dl, "gtcon")
_check_scipy_linalg_matrix(d, "gtcon")
_check_scipy_linalg_matrix(du, "gtcon")
_check_scipy_linalg_matrix(du2, "gtcon")
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gtcon")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gtcon")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gtcon")
_check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gtcon")
_check_dtypes_match((dl, d, du, du2), func_name="gtcon")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gtcon")
dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gtcon = _LAPACK().numba_xgtcon(dtype)
def impl(
......@@ -248,14 +253,14 @@ def gtcon_impl(
numba_gtcon(
val_to_int_ptr(ord(norm)),
val_to_int_ptr(n),
dl.view(w_type).ctypes,
d.view(w_type).ctypes,
du.view(w_type).ctypes,
du2.view(w_type).ctypes,
dl.ctypes,
d.ctypes,
du.ctypes,
du2.ctypes,
ipiv.ctypes,
np.array(anorm, dtype=dtype).view(w_type).ctypes,
rcond.view(w_type).ctypes,
work.view(w_type).ctypes,
np.array(anorm, dtype=dtype).ctypes,
rcond.ctypes,
work.ctypes,
iwork.ctypes,
info,
)
......@@ -300,8 +305,9 @@ def _tridiagonal_solve_impl(
transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
A: ndarray,
......@@ -342,12 +348,26 @@ def _tridiagonal_solve_impl(
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node)
overwrite_dl = op.overwrite_dl
overwrite_d = op.overwrite_d
overwrite_du = op.overwrite_du
out_dtype = node.outputs[1].type.numpy_dtype
must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
if any(must_cast_inputs) and config.compiler_verbose:
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du):
if must_cast_inputs[0]:
d = d.astype(out_dtype)
if must_cast_inputs[1]:
dl = dl.astype(out_dtype)
if must_cast_inputs[2]:
du = du.astype(out_dtype)
dl, d, du, du2, ipiv, _ = _gttrf(
dl,
d,
......@@ -365,11 +385,34 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
def numba_funcify_SolveLUFactorTridiagonal(
op: SolveLUFactorTridiagonal, node, **kwargs
):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node)
out_dtype = node.outputs[0].type.numpy_dtype
overwrite_b = op.overwrite_b
transposed = op.transposed
must_cast_inputs = tuple(
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
for i, inp in enumerate(node.inputs)
)
if any(must_cast_inputs) and config.compiler_verbose:
print("SolveLUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
if must_cast_inputs[0]:
dl = dl.astype(out_dtype)
if must_cast_inputs[1]:
d = d.astype(out_dtype)
if must_cast_inputs[2]:
du = du.astype(out_dtype)
if must_cast_inputs[3]:
du2 = du2.astype(out_dtype)
if must_cast_inputs[4]:
ipiv = ipiv.astype("int32")
if must_cast_inputs[5]:
b = b.astype(out_dtype)
x, _ = _gttrs(
dl,
d,
......
from collections.abc import Callable
from collections.abc import Callable, Sequence
import numba
from numba.core import types
......@@ -32,24 +32,34 @@ def _trans_char_to_int(trans):
return ord("C")
def _check_scipy_linalg_matrix(a, func_name):
def _check_linalg_matrix(a, *, ndim: int | Sequence[int], dtype, func_name):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix = "scipy.linalg"
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = f"{prefix}.{func_name}() only supported for array types"
msg = f"{func_name} only supported for array types"
raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]:
msg = (
f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, types.Float | types.Complex):
msg = f"{prefix}.{func_name}() only supported on float and complex arrays."
ndim_msg = f"{func_name} only supported on {ndim}d arrays, got {a.ndim}."
if isinstance(ndim, int):
if a.ndim != ndim:
raise numba.TypingError(ndim_msg, highlighting=False)
elif a.ndim not in ndim:
raise numba.TypingError(ndim_msg, highlighting=False)
dtype_msg = f"{func_name} only supported for {dtype}, got {a.dtype}."
if isinstance(dtype, type | tuple):
if not isinstance(a.dtype, dtype):
raise numba.TypingError(dtype_msg, highlighting=False)
elif a.dtype != dtype:
raise numba.TypingError(dtype_msg, highlighting=False)
def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
dtypes = [a.dtype for a in arrays]
first_dtype = dtypes[0]
for other_dtype in dtypes[1:]:
if first_dtype != other_dtype:
msg = f"{func_name} only supported for matching dtypes, got {dtypes}"
raise numba.TypingError(msg, highlighting=False)
......
......@@ -63,13 +63,20 @@ def numba_funcify_Cholesky(op, node, **kwargs):
check_finite = op.check_finite
on_error = op.on_error
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("Cholesky requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
@numba_basic.numba_njit
def cholesky(a):
if check_finite:
if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
......@@ -112,18 +119,24 @@ def pivot_to_permutation(op, node, **kwargs):
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("LU requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
permute_l = op.permute_l
check_finite = op.check_finite
p_indices = op.p_indices
overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit
def lu(a):
if check_finite:
if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to lu"
......@@ -161,16 +174,22 @@ def numba_funcify_LU(op, node, **kwargs):
@numba_funcify.register(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
dtype = node.inputs[0].dtype
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("LUFactor requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
check_finite = op.check_finite
overwrite_a = op.overwrite_a
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit
def lu_factor(a):
if check_finite:
if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
......@@ -207,6 +226,21 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("Solve requires casting first input `A`") # noqa: T201
must_cast_B = b_dtype != out_dtype
if must_cast_B and config.compiler_verbose:
print("Solve requires casting second input `b`") # noqa: T201
check_finite = op.check_finite
overwrite_a = op.overwrite_a
assume_a = op.assume_a
lower = op.lower
check_finite = op.check_finite
......@@ -214,10 +248,6 @@ def numba_funcify_Solve(op, node, **kwargs):
overwrite_b = op.overwrite_b
transposed = False # TODO: Solve doesnt currently allow the transposed argument
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
if assume_a == "gen":
solve_fn = _solve_gen
elif assume_a == "sym":
......@@ -239,6 +269,10 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba_basic.numba_njit
def solve(a, b):
if must_cast_A:
a = a.astype(out_dtype)
if must_cast_B:
b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
......@@ -263,14 +297,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
overwrite_b = op.overwrite_b
b_ndim = op.b_ndim
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
raise NotImplementedError(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
)
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201
must_cast_B = b_dtype != out_dtype
if must_cast_B and config.compiler_verbose:
print("SolveTriangular requires casting second input `b`") # noqa: T201
@numba_basic.numba_njit
def solve_triangular(a, b):
if must_cast_A:
a = a.astype(out_dtype)
if must_cast_B:
b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
......@@ -302,24 +346,42 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
overwrite_b = op.overwrite_b
check_finite = op.check_finite
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if c_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_c = c_dtype != out_dtype
if must_cast_c and config.compiler_verbose:
print("CholeskySolve requires casting first input `c`") # noqa: T201
must_cast_b = b_dtype != out_dtype
if must_cast_b and config.compiler_verbose:
print("CholeskySolve requires casting second input `b`") # noqa: T201
@numba_basic.numba_njit
def cho_solve(c, b):
if must_cast_c:
c = c.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if must_cast_b:
b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
c,
b,
lower=lower,
overwrite_b=overwrite_b,
check_finite=check_finite,
)
return cho_solve
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论