提交 bf0fe7af authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba tridiagonal solve

上级 2741295a
......@@ -2,7 +2,7 @@ from collections.abc import Callable
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.core.types import Complex, Float, int32
from numba.np.linalg import ensure_lapack
from numpy import ndarray
from scipy import linalg
......@@ -67,9 +67,9 @@ def gttrf_impl(
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]:
ensure_lapack()
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrf")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrf")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrf")
_check_linalg_matrix(dl, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_linalg_matrix(d, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_linalg_matrix(du, ndim=1, dtype=(Float, Complex), func_name="gttrf")
_check_dtypes_match((dl, d, du), func_name="gttrf")
dtype = d.dtype
numba_gttrf = _LAPACK().numba_xgttrf(dtype)
......@@ -140,11 +140,11 @@ def gttrs_impl(
tuple[ndarray, int],
]:
ensure_lapack()
_check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gttrs")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="gttrs")
_check_linalg_matrix(dl, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(d, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(du, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(du2, ndim=1, dtype=(Float, Complex), func_name="gttrs")
_check_linalg_matrix(b, ndim=(1, 2), dtype=(Float, Complex), func_name="gttrs")
_check_dtypes_match((dl, d, du, du2, b), func_name="gttrs")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gttrs")
dtype = d.dtype
......@@ -234,8 +234,8 @@ def _tridiagonal_solve_impl(
transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool], ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论