提交 bf0fe7af authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba tridiagonal solve

上级 2741295a
...@@ -2,7 +2,7 @@ from collections.abc import Callable ...@@ -2,7 +2,7 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float, int32 from numba.core.types import Complex, Float, int32
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from numpy import ndarray from numpy import ndarray
from scipy import linalg from scipy import linalg
...@@ -67,9 +67,9 @@ def gttrf_impl( ...@@ -67,9 +67,9 @@ def gttrf_impl(
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrf") _check_linalg_matrix(dl, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrf") _check_linalg_matrix(d, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrf") _check_linalg_matrix(du, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_dtypes_match((dl, d, du), func_name="gttrf") _check_dtypes_match((dl, d, du), func_name="gttrf")
dtype = d.dtype dtype = d.dtype
numba_gttrf = _LAPACK().numba_xgttrf(dtype) numba_gttrf = _LAPACK().numba_xgttrf(dtype)
...@@ -140,11 +140,11 @@ def gttrs_impl( ...@@ -140,11 +140,11 @@ def gttrs_impl(
tuple[ndarray, int], tuple[ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrs") _check_linalg_matrix(dl, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrs") _check_linalg_matrix(d, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrs") _check_linalg_matrix(du, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gttrs") _check_linalg_matrix(du2, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="gttrs") _check_linalg_matrix(b, ndim=(1, 2), dtype=(Float, Complex), func_name="gttrs")
_check_dtypes_match((dl, d, du, du2, b), func_name="gttrs") _check_dtypes_match((dl, d, du, du2, b), func_name="gttrs")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gttrs") _check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gttrs")
dtype = d.dtype dtype = d.dtype
...@@ -234,8 +234,8 @@ def _tridiagonal_solve_impl( ...@@ -234,8 +234,8 @@ def _tridiagonal_solve_impl(
transposed: bool, transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]: ) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]:
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve") _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve")
_check_dtypes_match((A, B), func_name="solve") _check_dtypes_match((A, B), func_name="solve")
def impl( def impl(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论