提交 b37dd14b authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba lu and lu_factor

上级 e3290af0
......@@ -3,7 +3,7 @@ from typing import cast as typing_cast
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
......@@ -36,7 +36,7 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="getrf")
dtype = A.dtype
numba_getrf = _LAPACK().numba_xgetrf(dtype)
......@@ -76,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a)
......
......@@ -3,7 +3,7 @@ from typing import Literal, TypeAlias
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.core.types import Complex, Float, int32
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg
......@@ -44,8 +44,8 @@ def getrs_impl(
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs")
_check_linalg_matrix(LU, ndim=2, dtype=(Float, Complex), func_name="getrs")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="getrs")
_check_dtypes_match((LU, B), func_name="getrs")
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
dtype = LU.dtype
......@@ -122,8 +122,8 @@ def lu_solve_impl(
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]:
ensure_lapack()
lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve")
_check_linalg_matrix(lu, ndim=2, dtype=(Float, Complex), func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=(Float, Complex), func_name="lu_solve")
_check_dtypes_match((lu, b), func_name="lu_solve")
def impl(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论