提交 b37dd14b authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba lu and lu_factor

上级 e3290af0
...@@ -3,7 +3,7 @@ from typing import cast as typing_cast ...@@ -3,7 +3,7 @@ from typing import cast as typing_cast
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
...@@ -36,7 +36,7 @@ def getrf_impl( ...@@ -36,7 +36,7 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: ) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf") _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="getrf")
dtype = A.dtype dtype = A.dtype
numba_getrf = _LAPACK().numba_xgetrf(dtype) numba_getrf = _LAPACK().numba_xgetrf(dtype)
...@@ -76,7 +76,7 @@ def lu_factor_impl( ...@@ -76,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]: ) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor") _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]: def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a) A_copy, IPIV, info = _getrf(A, overwrite_a=overwrite_a)
......
...@@ -3,7 +3,7 @@ from typing import Literal, TypeAlias ...@@ -3,7 +3,7 @@ from typing import Literal, TypeAlias
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float, int32 from numba.core.types import Complex, Float, int32
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
...@@ -44,8 +44,8 @@ def getrs_impl( ...@@ -44,8 +44,8 @@ def getrs_impl(
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int] [np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]: ]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs") _check_linalg_matrix(LU, ndim=2, dtype=(Float, Complex), func_name="getrs")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs") _check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="getrs")
_check_dtypes_match((LU, B), func_name="getrs") _check_dtypes_match((LU, B), func_name="getrs")
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs") _check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
dtype = LU.dtype dtype = LU.dtype
...@@ -122,8 +122,8 @@ def lu_solve_impl( ...@@ -122,8 +122,8 @@ def lu_solve_impl(
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
lu, _piv = lu_and_piv lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve") _check_linalg_matrix(lu, ndim=2, dtype=(Float, Complex), func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve") _check_linalg_matrix(b, ndim=(1, 2), dtype=(Float, Complex), func_name="lu_solve")
_check_dtypes_match((lu, b), func_name="lu_solve") _check_dtypes_match((lu, b), func_name="lu_solve")
def impl( def impl(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论