提交 ffd999c8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba linalg: handle dtypes more strictly

上级 edb1b205
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
...@@ -24,9 +24,9 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): ...@@ -24,9 +24,9 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@overload(_cholesky) @overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True): def impl(A, lower=0, overwrite_a=False, check_finite=True):
...@@ -47,7 +47,7 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ...@@ -47,7 +47,7 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
numba_potrf( numba_potrf(
UPLO, UPLO,
N, N,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
INFO, INFO,
) )
......
...@@ -3,12 +3,13 @@ from typing import Literal ...@@ -3,12 +3,13 @@ from typing import Literal
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
@numba_basic.numba_njit @numba_basic.numba_njit
...@@ -116,7 +117,7 @@ def lu_impl_1( ...@@ -116,7 +117,7 @@ def lu_impl_1(
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A. False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
""" """
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(a, "lu") _check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype dtype = a.dtype
def impl( def impl(
...@@ -146,7 +147,7 @@ def lu_impl_2( ...@@ -146,7 +147,7 @@ def lu_impl_2(
""" """
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(a, "lu") _check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype dtype = a.dtype
def impl( def impl(
...@@ -179,7 +180,7 @@ def lu_impl_3( ...@@ -179,7 +180,7 @@ def lu_impl_3(
False. Returns a tuple of (P, L, U), such that P @ L @ U = A. False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
""" """
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(a, "lu") _check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype dtype = a.dtype
def impl( def impl(
......
...@@ -3,18 +3,16 @@ from typing import cast as typing_cast ...@@ -3,18 +3,16 @@ from typing import cast as typing_cast
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
_check_scipy_linalg_matrix,
)
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
...@@ -38,9 +36,8 @@ def getrf_impl( ...@@ -38,9 +36,8 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: ) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype) numba_getrf = _LAPACK().numba_xgetrf(dtype)
def impl( def impl(
...@@ -59,7 +56,7 @@ def getrf_impl( ...@@ -59,7 +56,7 @@ def getrf_impl(
IPIV = np.empty(_N, dtype=np.int32) # type: ignore IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0) INFO = val_to_int_ptr(0)
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)
return A_copy, IPIV, int_ptr_to_val(INFO) return A_copy, IPIV, int_ptr_to_val(INFO)
...@@ -79,7 +76,7 @@ def lu_factor_impl( ...@@ -79,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]: ) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "lu_factor") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]: def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
......
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
) )
...@@ -31,10 +32,10 @@ def _cho_solve( ...@@ -31,10 +32,10 @@ def _cho_solve(
@overload(_cho_solve) @overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve") _check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
_check_dtypes_match((C, B), func_name="cho_solve")
dtype = C.dtype dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype) numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False, check_finite=True): def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
...@@ -71,9 +72,9 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -71,9 +72,9 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO, UPLO,
N, N,
NRHS, NRHS,
C_f.view(w_type).ctypes, C_f.ctypes,
LDA, LDA,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
INFO, INFO,
) )
......
...@@ -2,12 +2,12 @@ from collections.abc import Callable ...@@ -2,12 +2,12 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
...@@ -16,7 +16,8 @@ from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs ...@@ -16,7 +16,8 @@ from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_solve_check, _solve_check,
) )
...@@ -37,9 +38,8 @@ def xgecon_impl( ...@@ -37,9 +38,8 @@ def xgecon_impl(
Compute the condition number of a matrix A. Compute the condition number of a matrix A.
""" """
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype) numba_gecon = _LAPACK().numba_xgecon(dtype)
def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
...@@ -58,11 +58,11 @@ def xgecon_impl( ...@@ -58,11 +58,11 @@ def xgecon_impl(
numba_gecon( numba_gecon(
NORM, NORM,
N, N,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
A_NORM.view(w_type).ctypes, A_NORM.ctypes,
RCOND.view(w_type).ctypes, RCOND.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
IWORK.ctypes, IWORK.ctypes,
INFO, INFO,
) )
...@@ -106,8 +106,9 @@ def solve_gen_impl( ...@@ -106,8 +106,9 @@ def solve_gen_impl(
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_scipy_linalg_matrix(B, "solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), "solve")
def impl( def impl(
A: np.ndarray, A: np.ndarray,
......
...@@ -3,18 +3,19 @@ from typing import Literal, TypeAlias ...@@ -3,18 +3,19 @@ from typing import Literal, TypeAlias
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
_trans_char_to_int, _trans_char_to_int,
...@@ -44,10 +45,11 @@ def getrs_impl( ...@@ -44,10 +45,11 @@ def getrs_impl(
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int] [np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs") _check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs")
_check_scipy_linalg_matrix(B, "getrs") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs")
_check_dtypes_match((LU, B), func_name="getrs")
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
dtype = LU.dtype dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype) numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl( def impl(
...@@ -84,10 +86,10 @@ def getrs_impl( ...@@ -84,10 +86,10 @@ def getrs_impl(
TRANS, TRANS,
N, N,
NRHS, NRHS,
LU.view(w_type).ctypes, LU.ctypes,
LDA, LDA,
IPIV.ctypes, IPIV.ctypes,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
INFO, INFO,
) )
...@@ -124,8 +126,10 @@ def lu_solve_impl( ...@@ -124,8 +126,10 @@ def lu_solve_impl(
check_finite: bool, check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve") lu, _piv = lu_and_piv
_check_scipy_linalg_matrix(b, "lu_solve") _check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve")
_check_dtypes_match((lu, b), func_name="lu_solve")
def impl( def impl(
lu: np.ndarray, lu: np.ndarray,
......
...@@ -2,14 +2,14 @@ from collections.abc import Callable ...@@ -2,14 +2,14 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xlange(A: np.ndarray, order: str | None = None) -> float: def _xlange(A: np.ndarray, order: str | None = None) -> float:
...@@ -28,9 +28,8 @@ def xlange_impl( ...@@ -28,9 +28,8 @@ def xlange_impl(
largest absolute value of a matrix A. largest absolute value of a matrix A.
""" """
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "norm") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_lange = _LAPACK().numba_xlange(dtype) numba_lange = _LAPACK().numba_xlange(dtype)
def impl(A: np.ndarray, order: str | None = None): def impl(A: np.ndarray, order: str | None = None):
...@@ -49,9 +48,7 @@ def xlange_impl( ...@@ -49,9 +48,7 @@ def xlange_impl(
) )
WORK = np.empty(_M, dtype=dtype) # type: ignore WORK = np.empty(_M, dtype=dtype) # type: ignore
result = numba_lange( result = numba_lange(NORM, M, N, A_copy.ctypes, LDA, WORK.ctypes)
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
)
return result return result
......
...@@ -2,19 +2,20 @@ from collections.abc import Callable ...@@ -2,19 +2,20 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
) )
...@@ -49,10 +50,10 @@ def posv_impl( ...@@ -49,10 +50,10 @@ def posv_impl(
tuple[np.ndarray, np.ndarray, int], tuple[np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_scipy_linalg_matrix(B, "solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_posv = _LAPACK().numba_xposv(dtype) numba_posv = _LAPACK().numba_xposv(dtype)
def impl( def impl(
...@@ -99,9 +100,9 @@ def posv_impl( ...@@ -99,9 +100,9 @@ def posv_impl(
UPLO, UPLO,
N, N,
NRHS, NRHS,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
INFO, INFO,
) )
...@@ -127,9 +128,8 @@ def pocon_impl( ...@@ -127,9 +128,8 @@ def pocon_impl(
A: np.ndarray, anorm: float A: np.ndarray, anorm: float
) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]: ) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "pocon") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="pocon")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_pocon = _LAPACK().numba_xpocon(dtype) numba_pocon = _LAPACK().numba_xpocon(dtype)
def impl(A: np.ndarray, anorm: float): def impl(A: np.ndarray, anorm: float):
...@@ -148,11 +148,11 @@ def pocon_impl( ...@@ -148,11 +148,11 @@ def pocon_impl(
numba_pocon( numba_pocon(
UPLO, UPLO,
N, N,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
ANORM.view(w_type).ctypes, ANORM.ctypes,
RCOND.view(w_type).ctypes, RCOND.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
IWORK.ctypes, IWORK.ctypes,
INFO, INFO,
) )
...@@ -196,8 +196,9 @@ def solve_psd_impl( ...@@ -196,8 +196,9 @@ def solve_psd_impl(
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_scipy_linalg_matrix(B, "solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl( def impl(
A: np.ndarray, A: np.ndarray,
......
...@@ -2,19 +2,20 @@ from collections.abc import Callable ...@@ -2,19 +2,20 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
) )
...@@ -37,10 +38,10 @@ def sysv_impl( ...@@ -37,10 +38,10 @@ def sysv_impl(
tuple[np.ndarray, np.ndarray, np.ndarray, int], tuple[np.ndarray, np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sysv")
_check_scipy_linalg_matrix(B, "sysv") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="sysv")
_check_dtypes_match((A, B), func_name="sysv")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sysv = _LAPACK().numba_xsysv(dtype) numba_sysv = _LAPACK().numba_xsysv(dtype)
def impl( def impl(
...@@ -84,12 +85,12 @@ def sysv_impl( ...@@ -84,12 +85,12 @@ def sysv_impl(
UPLO, UPLO,
N, N,
NRHS, NRHS,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
IPIV.ctypes, IPIV.ctypes,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -103,12 +104,12 @@ def sysv_impl( ...@@ -103,12 +104,12 @@ def sysv_impl(
UPLO, UPLO,
N, N,
NRHS, NRHS,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
IPIV.ctypes, IPIV.ctypes,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -133,9 +134,8 @@ def sycon_impl( ...@@ -133,9 +134,8 @@ def sycon_impl(
A: np.ndarray, ipiv: np.ndarray, anorm: float A: np.ndarray, ipiv: np.ndarray, anorm: float
) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: ) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "sycon") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sycon")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_sycon = _LAPACK().numba_xsycon(dtype) numba_sycon = _LAPACK().numba_xsycon(dtype)
def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
...@@ -154,12 +154,12 @@ def sycon_impl( ...@@ -154,12 +154,12 @@ def sycon_impl(
numba_sycon( numba_sycon(
UPLO, UPLO,
N, N,
A_copy.view(w_type).ctypes, A_copy.ctypes,
LDA, LDA,
ipiv.ctypes, ipiv.ctypes,
ANORM.view(w_type).ctypes, ANORM.ctypes,
RCOND.view(w_type).ctypes, RCOND.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
IWORK.ctypes, IWORK.ctypes,
INFO, INFO,
) )
...@@ -203,8 +203,9 @@ def solve_symmetric_impl( ...@@ -203,8 +203,9 @@ def solve_symmetric_impl(
transposed: bool, transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_scipy_linalg_matrix(B, "solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl( def impl(
A: np.ndarray, A: np.ndarray,
......
import numpy as np import numpy as np
from numba.core import types from numba.core import types
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
_trans_char_to_int, _trans_char_to_int,
...@@ -45,10 +46,10 @@ def _solve_triangular( ...@@ -45,10 +46,10 @@ def _solve_triangular(
def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve_triangular") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve_triangular")
_check_dtypes_match((A, B), func_name="solve_triangular")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex): if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
...@@ -99,9 +100,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -99,9 +100,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
DIAG, DIAG,
N, N,
NRHS, NRHS,
A_f.view(w_type).ctypes, A_f.ctypes,
LDA, LDA,
B_copy.view(w_type).ctypes, B_copy.ctypes,
LDB, LDB,
INFO, INFO,
) )
......
...@@ -2,21 +2,24 @@ from collections.abc import Callable ...@@ -2,21 +2,24 @@ from collections.abc import Callable
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from numpy import ndarray from numpy import ndarray
from scipy import linalg from scipy import linalg
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float,
int_ptr_to_val, int_ptr_to_val,
val_to_int_ptr, val_to_int_ptr,
) )
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import ( from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix, _check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d, _copy_to_fortran_order_even_if_1d,
_solve_check, _solve_check,
_trans_char_to_int, _trans_char_to_int,
...@@ -63,11 +66,11 @@ def gttrf_impl( ...@@ -63,11 +66,11 @@ def gttrf_impl(
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrf") _check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrf")
_check_scipy_linalg_matrix(d, "gttrf") _check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrf")
_check_scipy_linalg_matrix(du, "gttrf") _check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrf")
_check_dtypes_match((dl, d, du), func_name="gttrf")
dtype = d.dtype dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gttrf = _LAPACK().numba_xgttrf(dtype) numba_gttrf = _LAPACK().numba_xgttrf(dtype)
def impl( def impl(
...@@ -94,10 +97,10 @@ def gttrf_impl( ...@@ -94,10 +97,10 @@ def gttrf_impl(
numba_gttrf( numba_gttrf(
val_to_int_ptr(n), val_to_int_ptr(n),
dl.view(w_type).ctypes, dl.ctypes,
d.view(w_type).ctypes, d.ctypes,
du.view(w_type).ctypes, du.ctypes,
du2.view(w_type).ctypes, du2.ctypes,
ipiv.ctypes, ipiv.ctypes,
info, info,
) )
...@@ -136,13 +139,14 @@ def gttrs_impl( ...@@ -136,13 +139,14 @@ def gttrs_impl(
tuple[ndarray, int], tuple[ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrs") _check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gttrs")
_check_scipy_linalg_matrix(d, "gttrs") _check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gttrs")
_check_scipy_linalg_matrix(du, "gttrs") _check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gttrs")
_check_scipy_linalg_matrix(du2, "gttrs") _check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gttrs")
_check_scipy_linalg_matrix(b, "gttrs") _check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="gttrs")
_check_dtypes_match((dl, d, du, du2, b), func_name="gttrs")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gttrs")
dtype = d.dtype dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gttrs = _LAPACK().numba_xgttrs(dtype) numba_gttrs = _LAPACK().numba_xgttrs(dtype)
def impl( def impl(
...@@ -181,12 +185,12 @@ def gttrs_impl( ...@@ -181,12 +185,12 @@ def gttrs_impl(
val_to_int_ptr(_trans_char_to_int(trans)), val_to_int_ptr(_trans_char_to_int(trans)),
val_to_int_ptr(n), val_to_int_ptr(n),
val_to_int_ptr(nrhs), val_to_int_ptr(nrhs),
dl.view(w_type).ctypes, dl.ctypes,
d.view(w_type).ctypes, d.ctypes,
du.view(w_type).ctypes, du.ctypes,
du2.view(w_type).ctypes, du2.ctypes,
ipiv.ctypes, ipiv.ctypes,
b.view(w_type).ctypes, b.ctypes,
val_to_int_ptr(n), val_to_int_ptr(n),
info, info,
) )
...@@ -222,12 +226,13 @@ def gtcon_impl( ...@@ -222,12 +226,13 @@ def gtcon_impl(
[ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int] [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(dl, "gtcon") _check_linalg_matrix(dl, ndim=1, dtype=Float, func_name="gtcon")
_check_scipy_linalg_matrix(d, "gtcon") _check_linalg_matrix(d, ndim=1, dtype=Float, func_name="gtcon")
_check_scipy_linalg_matrix(du, "gtcon") _check_linalg_matrix(du, ndim=1, dtype=Float, func_name="gtcon")
_check_scipy_linalg_matrix(du2, "gtcon") _check_linalg_matrix(du2, ndim=1, dtype=Float, func_name="gtcon")
_check_dtypes_match((dl, d, du, du2), func_name="gtcon")
_check_linalg_matrix(ipiv, ndim=1, dtype=int32, func_name="gtcon")
dtype = d.dtype dtype = d.dtype
w_type = _get_underlying_float(dtype)
numba_gtcon = _LAPACK().numba_xgtcon(dtype) numba_gtcon = _LAPACK().numba_xgtcon(dtype)
def impl( def impl(
...@@ -248,14 +253,14 @@ def gtcon_impl( ...@@ -248,14 +253,14 @@ def gtcon_impl(
numba_gtcon( numba_gtcon(
val_to_int_ptr(ord(norm)), val_to_int_ptr(ord(norm)),
val_to_int_ptr(n), val_to_int_ptr(n),
dl.view(w_type).ctypes, dl.ctypes,
d.view(w_type).ctypes, d.ctypes,
du.view(w_type).ctypes, du.ctypes,
du2.view(w_type).ctypes, du2.ctypes,
ipiv.ctypes, ipiv.ctypes,
np.array(anorm, dtype=dtype).view(w_type).ctypes, np.array(anorm, dtype=dtype).ctypes,
rcond.view(w_type).ctypes, rcond.ctypes,
work.view(w_type).ctypes, work.ctypes,
iwork.ctypes, iwork.ctypes,
info, info,
) )
...@@ -300,8 +305,9 @@ def _tridiagonal_solve_impl( ...@@ -300,8 +305,9 @@ def _tridiagonal_solve_impl(
transposed: bool, transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]: ) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_scipy_linalg_matrix(B, "solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl( def impl(
A: ndarray, A: ndarray,
...@@ -342,12 +348,26 @@ def _tridiagonal_solve_impl( ...@@ -342,12 +348,26 @@ def _tridiagonal_solve_impl(
@numba_funcify.register(LUFactorTridiagonal) @numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node)
overwrite_dl = op.overwrite_dl overwrite_dl = op.overwrite_dl
overwrite_d = op.overwrite_d overwrite_d = op.overwrite_d
overwrite_du = op.overwrite_du overwrite_du = op.overwrite_du
out_dtype = node.outputs[1].type.numpy_dtype
must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
if any(must_cast_inputs) and config.compiler_verbose:
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du): def lu_factor_tridiagonal(dl, d, du):
if must_cast_inputs[0]:
d = d.astype(out_dtype)
if must_cast_inputs[1]:
dl = dl.astype(out_dtype)
if must_cast_inputs[2]:
du = du.astype(out_dtype)
dl, d, du, du2, ipiv, _ = _gttrf( dl, d, du, du2, ipiv, _ = _gttrf(
dl, dl,
d, d,
...@@ -365,11 +385,34 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -365,11 +385,34 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
def numba_funcify_SolveLUFactorTridiagonal( def numba_funcify_SolveLUFactorTridiagonal(
op: SolveLUFactorTridiagonal, node, **kwargs op: SolveLUFactorTridiagonal, node, **kwargs
): ):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node)
out_dtype = node.outputs[0].type.numpy_dtype
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = op.transposed transposed = op.transposed
must_cast_inputs = tuple(
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
for i, inp in enumerate(node.inputs)
)
if any(must_cast_inputs) and config.compiler_verbose:
print("SolveLUFactorTridiagonal requires casting at least one input") # noqa: T201
@numba_basic.numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b): def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
if must_cast_inputs[0]:
dl = dl.astype(out_dtype)
if must_cast_inputs[1]:
d = d.astype(out_dtype)
if must_cast_inputs[2]:
du = du.astype(out_dtype)
if must_cast_inputs[3]:
du2 = du2.astype(out_dtype)
if must_cast_inputs[4]:
ipiv = ipiv.astype("int32")
if must_cast_inputs[5]:
b = b.astype(out_dtype)
x, _ = _gttrs( x, _ = _gttrs(
dl, dl,
d, d,
......
from collections.abc import Callable from collections.abc import Callable, Sequence
import numba import numba
from numba.core import types from numba.core import types
...@@ -32,24 +32,34 @@ def _trans_char_to_int(trans): ...@@ -32,24 +32,34 @@ def _trans_char_to_int(trans):
return ord("C") return ord("C")
def _check_scipy_linalg_matrix(a, func_name): def _check_linalg_matrix(a, *, ndim: int | Sequence[int], dtype, func_name):
""" """
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
""" """
prefix = "scipy.linalg"
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array): if not isinstance(a, types.Array):
msg = f"{prefix}.{func_name}() only supported for array types" msg = f"{func_name} only supported for array types"
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]: ndim_msg = f"{func_name} only supported on {ndim}d arrays, got {a.ndim}."
msg = ( if isinstance(ndim, int):
f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}." if a.ndim != ndim:
) raise numba.TypingError(ndim_msg, highlighting=False)
raise numba.TypingError(msg, highlighting=False) elif a.ndim not in ndim:
if not isinstance(a.dtype, types.Float | types.Complex): raise numba.TypingError(ndim_msg, highlighting=False)
msg = f"{prefix}.{func_name}() only supported on float and complex arrays."
dtype_msg = f"{func_name} only supported for {dtype}, got {a.dtype}."
if isinstance(dtype, type | tuple):
if not isinstance(a.dtype, dtype):
raise numba.TypingError(dtype_msg, highlighting=False)
elif a.dtype != dtype:
raise numba.TypingError(dtype_msg, highlighting=False)
def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
dtypes = [a.dtype for a in arrays]
first_dtype = dtypes[0]
for other_dtype in dtypes[1:]:
if first_dtype != other_dtype:
msg = f"{func_name} only supported for matching dtypes, got {dtypes}"
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
......
...@@ -63,13 +63,20 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -63,13 +63,20 @@ def numba_funcify_Cholesky(op, node, **kwargs):
check_finite = op.check_finite check_finite = op.check_finite
on_error = op.on_error on_error = op.on_error
dtype = node.inputs[0].dtype inp_dtype = node.inputs[0].type.numpy_dtype
if dtype in complex_dtypes: if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("Cholesky requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def cholesky(a): def cholesky(a):
if check_finite: if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky" "Non-numeric values (nan or inf) found in input to cholesky"
...@@ -112,18 +119,24 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -112,18 +119,24 @@ def pivot_to_permutation(op, node, **kwargs):
@numba_funcify.register(LU) @numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs): def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("LU requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
permute_l = op.permute_l permute_l = op.permute_l
check_finite = op.check_finite check_finite = op.check_finite
p_indices = op.p_indices p_indices = op.p_indices
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit @numba_basic.numba_njit
def lu(a): def lu(a):
if check_finite: if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to lu" "Non-numeric values (nan or inf) found in input to lu"
...@@ -161,16 +174,22 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -161,16 +174,22 @@ def numba_funcify_LU(op, node, **kwargs):
@numba_funcify.register(LUFactor) @numba_funcify.register(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs): def numba_funcify_LUFactor(op, node, **kwargs):
dtype = node.inputs[0].dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose:
print("LUFactor requires casting discrete input to float") # noqa: T201
out_dtype = node.outputs[0].type.numpy_dtype
check_finite = op.check_finite check_finite = op.check_finite
overwrite_a = op.overwrite_a overwrite_a = op.overwrite_a
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit @numba_basic.numba_njit
def lu_factor(a): def lu_factor(a):
if check_finite: if discrete_inp:
a = a.astype(out_dtype)
elif check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky" "Non-numeric values (nan or inf) found in input to cholesky"
...@@ -207,6 +226,21 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -207,6 +226,21 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("Solve requires casting first input `A`") # noqa: T201
must_cast_B = b_dtype != out_dtype
if must_cast_B and config.compiler_verbose:
print("Solve requires casting second input `b`") # noqa: T201
check_finite = op.check_finite
overwrite_a = op.overwrite_a
assume_a = op.assume_a assume_a = op.assume_a
lower = op.lower lower = op.lower
check_finite = op.check_finite check_finite = op.check_finite
...@@ -214,10 +248,6 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -214,10 +248,6 @@ def numba_funcify_Solve(op, node, **kwargs):
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = False # TODO: Solve doesnt currently allow the transposed argument transposed = False # TODO: Solve doesnt currently allow the transposed argument
dtype = node.inputs[0].dtype
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
if assume_a == "gen": if assume_a == "gen":
solve_fn = _solve_gen solve_fn = _solve_gen
elif assume_a == "sym": elif assume_a == "sym":
...@@ -239,6 +269,10 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -239,6 +269,10 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def solve(a, b): def solve(a, b):
if must_cast_A:
a = a.astype(out_dtype)
if must_cast_B:
b = b.astype(out_dtype)
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
...@@ -263,14 +297,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -263,14 +297,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
b_ndim = op.b_ndim b_ndim = op.b_ndim
dtype = node.inputs[0].dtype A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
if dtype in complex_dtypes: out_dtype = node.outputs[0].type.numpy_dtype
raise NotImplementedError(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") if A_dtype.kind == "c" or b_dtype.kind == "c":
) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201
must_cast_B = b_dtype != out_dtype
if must_cast_B and config.compiler_verbose:
print("SolveTriangular requires casting second input `b`") # noqa: T201
@numba_basic.numba_njit @numba_basic.numba_njit
def solve_triangular(a, b): def solve_triangular(a, b):
if must_cast_A:
a = a.astype(out_dtype)
if must_cast_B:
b = b.astype(out_dtype)
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
...@@ -302,24 +346,42 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -302,24 +346,42 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
check_finite = op.check_finite check_finite = op.check_finite
dtype = node.inputs[0].dtype c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
if dtype in complex_dtypes: out_dtype = node.outputs[0].type.numpy_dtype
if c_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
must_cast_c = c_dtype != out_dtype
if must_cast_c and config.compiler_verbose:
print("CholeskySolve requires casting first input `c`") # noqa: T201
must_cast_b = b_dtype != out_dtype
if must_cast_b and config.compiler_verbose:
print("CholeskySolve requires casting second input `b`") # noqa: T201
@numba_basic.numba_njit @numba_basic.numba_njit
def cho_solve(c, b): def cho_solve(c, b):
if must_cast_c:
c = c.astype(out_dtype)
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to cho_solve" "Non-numeric values (nan or inf) in input A to cho_solve"
) )
if must_cast_b:
b = b.astype(out_dtype)
if check_finite:
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to cho_solve" "Non-numeric values (nan or inf) in input b to cho_solve"
) )
return _cho_solve( return _cho_solve(
c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite c,
b,
lower=lower,
overwrite_b=overwrite_b,
check_finite=check_finite,
) )
return cho_solve return cho_solve
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论