Unverified 提交 5137ed3e authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Allow caching of cython functions in numba backend (#1807)

* Add machinery to enable numba caching of function pointers * Numba cache for Cholesky * Numba cache for SolveTriangular * Numba cache for CholeskySolve * Numba cache for solve helpers * Numba cache for GECON * Numba cache for lu_factor * Numba cache for Solve when assume_a="gen" * Numba cache for Solve when assume_a="sym" * Numba cache for Solve when assume_a="pos" * Numba cache for Solve when assume_a='tri' * Numba cache for QR * Clean up obsolete code * Feedback * More feedback * Rename `cache_key_lit` -> `unique_func_name_lit`
上级 0b439c0f
......@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile
from typing import Any
from weakref import WeakKeyDictionary
import numba
from llvmlite import ir
from numba.core import cgutils
from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config
......@@ -127,3 +130,49 @@ def compile_numba_function_src(
CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore
@numba.extending.intrinsic(prefer_literal=True)
def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref, unique_func_name_lit):
"""
Enable caching of function pointers returned by `get_ptr_func`.
When one of our Numba-dispatched functions depends on a pointer to a compiled function function (e.g. when we call
cython_lapack routines), numba will refuse to cache the function, because the pointer may change between runs.
This intrinsic allows us to cache the pointer ourselves, by storing it in a global variable keyed by a literal
`unique_func_name_lit`. The first time the intrinsic is called, it will call `get_ptr_func` to get the pointer, store it
in the global variable, and return it. Subsequent calls will load the pointer from the global variable.
"""
func_type = func_type_ref.instance_type
cache_key = unique_func_name_lit.literal_value
def codegen(context, builder, signature, args):
ptr_ty = ir.PointerType(ir.IntType(8))
null = ptr_ty(None)
align = 64
mod = builder.module
var = cgutils.add_global_variable(mod, ptr_ty, f"_ptr_cache_{cache_key}")
var.align = align
var.linkage = "private"
var.initializer = null
var_val = builder.load_atomic(var, "acquire", align)
result_ptr = cgutils.alloca_once_value(builder, var_val)
with builder.if_then(builder.icmp_signed("==", var_val, null), likely=False):
sig = typingctx.resolve_function_type(get_ptr_func, [], {})
f = context.get_function(get_ptr_func, sig)
new_ptr = f(builder, [])
new_ptr = builder.inttoptr(new_ptr, ptr_ty)
builder.store_atomic(new_ptr, var, "release", align)
builder.store(new_ptr, result_ptr)
sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
sfunc.c_addr = builder.load(result_ptr)
return sfunc._getvalue()
sig = func_type(get_ptr_func, func_type_ref, unique_func_name_lit)
return sig, codegen
import ctypes
import numba
import numpy as np
from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.registry import CPUDispatcher
from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind
from pytensor.link.numba.cache import _call_cached_ptr
from pytensor.link.numba.dispatch import basic as numba_basic
nb_i32 = types.int32
nb_i32p = types.CPointer(nb_i32)
_PTR = ctypes.POINTER
nb_f32 = types.float32
nb_f32p = types.CPointer(nb_f32)
_dbl = ctypes.c_double
_float = ctypes.c_float
_char = ctypes.c_char
_int = ctypes.c_int
nb_f64 = types.float64
nb_f64p = types.CPointer(nb_f64)
_ptr_float = _PTR(_float)
_ptr_dbl = _PTR(_dbl)
_ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)
nb_c64 = types.complex64
nb_c64p = types.CPointer(nb_c64)
nb_c128 = types.complex128
nb_c128p = types.CPointer(nb_c128)
def _get_lapack_ptr_and_ptr_type(dtype, name):
def get_lapack_ptr(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
return lapack_ptr, float_pointer
return lapack_ptr
def _get_underlying_float(dtype):
......@@ -40,19 +44,18 @@ def _get_underlying_float(dtype):
return np.dtype(out_type)
def _get_float_pointer_for_dtype(blas_dtype):
if blas_dtype in ["s", "c"]:
return _ptr_float
elif blas_dtype in ["d", "z"]:
return _ptr_dbl
def _get_output_ctype(dtype):
s_dtype = str(dtype)
if s_dtype in ["float32", "complex64"]:
return _float
elif s_dtype in ["float64", "complex128"]:
return _dbl
def _get_nb_float_from_dtype(blas_dtype, return_pointer=True):
match blas_dtype:
case "s":
return nb_f32p if return_pointer else nb_f32
case "d":
return nb_f64p if return_pointer else nb_f64
case "c":
return nb_c64p if return_pointer else nb_c64
case "z":
return nb_c128p if return_pointer else nb_c128
case _:
raise ValueError(f"Unsupported BLAS dtype: {blas_dtype}")
@intrinsic
......@@ -136,422 +139,815 @@ class _LAPACK:
ensure_lapack()
@classmethod
def numba_xtrtrs(cls, dtype):
def numba_xtrtrs(cls, dtype) -> CPUDispatcher:
"""
Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.
Called by scipy.linalg.solve_triangular
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # TRANS
_ptr_int, # DIAG
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
kind = get_blas_kind(dtype)
float_ptr = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}trtrs"
@numba_basic.numba_njit
def get_trtrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "trtrs")
return ptr
trtrs_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # TRANS
nb_i32p, # DIAG
nb_i32p, # N
nb_i32p, # NRHS
float_ptr, # A
nb_i32p, # LDA
float_ptr, # B
nb_i32p, # LDB
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def trtrs(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_trtrs_pointer,
func_type_ref=trtrs_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO)
return trtrs
@classmethod
def numba_xpotrf(cls, dtype):
def numba_xpotrf(cls, dtype) -> CPUDispatcher:
"""
Compute the Cholesky factorization of a real symmetric positive definite matrix.
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
kind = get_blas_kind(dtype)
float_ptr = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}potrf"
@numba_basic.numba_njit
def get_potrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "potrf")
return ptr
potrf_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def potrf(UPLO, N, A, LDA, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_potrf_pointer,
func_type_ref=potrf_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
fn(UPLO, N, A, LDA, INFO)
return potrf
@classmethod
def numba_xpotrs(cls, dtype):
def numba_xpotrs(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by numba_potrf.
Called by scipy.linalg.cho_solve
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}potrs"
@numba_basic.numba_njit
def get_potrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "potrs")
return ptr
potrs_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
nb_i32p, # LDB
nb_i32p, # INFO
)
return functype(lapack_ptr)
)
@numba_basic.numba_njit
def potrs(UPLO, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_potrs_pointer,
func_type_ref=potrs_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, B, LDB, INFO)
return potrs
@classmethod
def numba_xlange(cls, dtype):
def numba_xlange(cls, dtype) -> CPUDispatcher:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange")
output_ctype = _get_output_ctype(dtype)
functype = ctypes.CFUNCTYPE(
output_ctype, # Output
_ptr_int, # NORM
_ptr_int, # M
_ptr_int, # N
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
float_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)
unique_func_name = f"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def get_lange_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lange")
return ptr
lange_function_type = types.FunctionType(
float_type(
nb_i32p, # NORM
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # WORK
)
return functype(lapack_ptr)
)
@numba_basic.numba_njit
def lange(NORM, M, N, A, LDA, WORK):
fn = _call_cached_ptr(
get_ptr_func=get_lange_pointer,
func_type_ref=lange_function_type,
unique_func_name_lit=unique_func_name,
)
return fn(NORM, M, N, A, LDA, WORK)
return lange
@classmethod
def numba_xlamch(cls, dtype):
def numba_xlamch(cls, dtype) -> CPUDispatcher:
"""
Determine machine precision for floating point arithmetic.
"""
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
unique_func_name = f"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def get_lamch_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lamch")
return ptr
lamch_function_type = types.FunctionType(
float_type( # Return type
nb_i32p, # CMACH
)
)
lapack_ptr, _float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch")
output_dtype = _get_output_ctype(dtype)
functype = ctypes.CFUNCTYPE(
output_dtype, # Output
_ptr_int, # CMACH
@numba_basic.numba_njit
def lamch(CMACH):
fn = _call_cached_ptr(
get_ptr_func=get_lamch_pointer,
func_type_ref=lamch_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
res = fn(CMACH)
return res
return lamch
@classmethod
def numba_xgecon(cls, dtype):
def numba_xgecon(cls, dtype) -> CPUDispatcher:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # NORM
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def get_gecon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gecon")
return ptr
gecon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gecon_pointer,
func_type_ref=gecon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return gecon
@classmethod
def numba_xgetrf(cls, dtype):
def numba_xgetrf(cls, dtype) -> CPUDispatcher:
"""
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
Called by scipy.linalg.lu_factor
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}getrf"
@numba_basic.numba_njit
def get_getrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "getrf")
return ptr
getrf_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
_ptr_int, # INFO
nb_i32p, # LDA
nb_i32p, # IPIV
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def getrf(M, N, A, LDA, IPIV, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_getrf_pointer,
func_type_ref=getrf_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, A, LDA, IPIV, INFO)
return getrf
@classmethod
def numba_xgetrs(cls, dtype):
def numba_xgetrs(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
factorization computed by GETRF.
Called by scipy.linalg.lu_solve
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # N
_ptr_int, # NRHS
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}getrs"
@numba_basic.numba_njit
def get_getrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "getrs")
return ptr
getrs_function_type = types.FunctionType(
types.void(
nb_i32p, # TRANS
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
nb_i32p, # LDB
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def getrs(TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_getrs_pointer,
func_type_ref=getrs_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
fn(TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO)
return getrs
@classmethod
def numba_xsysv(cls, dtype):
def numba_xsysv(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
factorizing A into LDL^T or UDU^T form, depending on the value of UPLO
Called by scipy.linalg.solve when assume_a == "sym"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}sysv"
@numba_basic.numba_njit
def get_sysv_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sysv")
return ptr
sysv_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # B
_ptr_int, # LDB
nb_i32p, # LDB
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
nb_i32p, # LWORK
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def sysv(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sysv_pointer,
func_type_ref=sysv_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO)
return sysv
@classmethod
def numba_xsycon(cls, dtype):
def numba_xsycon(cls, dtype) -> CPUDispatcher:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def get_sycon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sycon")
return ptr
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
sycon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # IPIV
nb_i32p, # LDA
nb_i32p, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sycon_pointer,
func_type_ref=sycon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return sycon
@classmethod
def numba_xpocon(cls, dtype):
def numba_xpocon(cls, dtype) -> CPUDispatcher:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def get_pocon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "pocon")
return ptr
pocon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
nb_i32p, # IWORK
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_pocon_pointer,
func_type_ref=pocon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return pocon
@classmethod
def numba_xposv(cls, dtype):
def numba_xposv(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by potrf.
"""
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}posv"
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # N
_ptr_int, # NRHS
@numba_basic.numba_njit
def get_posv_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "posv")
return ptr
posv_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
nb_i32p, # LDB
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def posv(UPLO, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_posv_pointer,
func_type_ref=posv_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, B, LDB, INFO)
return posv
@classmethod
def numba_xgttrf(cls, dtype):
def numba_xgttrf(cls, dtype) -> CPUDispatcher:
"""
Compute the LU factorization of a tridiagonal matrix A using row interchanges.
Called by scipy.linalg.lu_factor
Called by scipy.linalg.solve when assume_a == "tri"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gttrf"
@numba_basic.numba_njit
def get_gttrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gttrf")
return ptr
gttrf_function_type = types.FunctionType(
types.void(
nb_i32p, # N
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
_ptr_int, # INFO
nb_i32p, # IPIV
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def gttrf(N, DL, D, DU, DU2, IPIV, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gttrf_pointer,
func_type_ref=gttrf_function_type,
unique_func_name_lit=unique_func_name,
)
fn(N, DL, D, DU, DU2, IPIV, INFO)
return gttrf
@classmethod
def numba_xgttrs(cls, dtype):
def numba_xgttrs(cls, dtype) -> CPUDispatcher:
"""
Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
Called by scipy.linalg.lu_solve
Called by scipy.linalg.solve, when assume_a == "tri"
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrs")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # N
_ptr_int, # NRHS
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gttrs"
@numba_basic.numba_njit
def get_gttrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gttrs")
return ptr
gttrs_function_type = types.FunctionType(
types.void(
nb_i32p, # TRANS
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
nb_i32p, # IPIV
float_pointer, # B
_ptr_int, # LDB
_ptr_int, # INFO
nb_i32p, # LDB
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def gttrs(TRANS, N, NRHS, DL, D, DU, DU2, IPIV, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gttrs_pointer,
func_type_ref=gttrs_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
fn(TRANS, N, NRHS, DL, D, DU, DU2, IPIV, B, LDB, INFO)
return gttrs
@classmethod
def numba_xgtcon(cls, dtype):
def numba_xgtcon(cls, dtype) -> CPUDispatcher:
"""
Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gtcon")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # NORM
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}gtcon"
@numba_basic.numba_njit
def get_gtcon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gtcon")
return ptr
gtcon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # DL
float_pointer, # D
float_pointer, # DU
float_pointer, # DU2
_ptr_int, # IPIV
nb_i32p, # IPIV
float_pointer, # ANORM
float_pointer, # RCOND
float_pointer, # WORK
_ptr_int, # IWORK
_ptr_int, # INFO
nb_i32p, # IWORK
nb_i32p, # INFO
)
return functype(lapack_ptr)
)
@numba_basic.numba_njit
def gtcon(NORM, N, DL, D, DU, DU2, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gtcon_pointer,
func_type_ref=gtcon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, DL, D, DU, DU2, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return gtcon
@classmethod
def numba_xgeqrf(cls, dtype):
def numba_xgeqrf(cls, dtype) -> CPUDispatcher:
"""
Compute the QR factorization of a general M-by-N matrix A.
Used in QR decomposition (no pivoting).
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}geqrf"
@numba_basic.numba_njit
def get_geqrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "geqrf")
return ptr
geqrf_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
nb_i32p, # LWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def geqrf(M, N, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_geqrf_pointer,
func_type_ref=geqrf_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
fn(M, N, A, LDA, TAU, WORK, LWORK, INFO)
return geqrf
@classmethod
def numba_xgeqp3(cls, dtype):
def numba_xgeqp3(cls, dtype) -> CPUDispatcher:
"""
Compute the QR factorization with column pivoting of a general M-by-N matrix A.
Used in QR decomposition with pivoting.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
ctype_args = (
_ptr_int, # M
_ptr_int, # N
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}geqp3"
@numba_basic.numba_njit
def get_geqp3_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "geqp3")
return ptr
if isinstance(dtype, Complex):
real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p
geqp3_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # JPVT
nb_i32p, # LDA
nb_i32p, # JPVT
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
nb_i32p, # LWORK
real_pointer, # RWORK
nb_i32p, # INFO
)
)
if isinstance(dtype, Complex):
ctype_args = (
*ctype_args,
float_pointer, # RWORK)
@numba_basic.numba_njit
def geqp3(M, N, A, LDA, JPVT, TAU, WORK, LWORK, RWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_geqp3_pointer,
func_type_ref=geqp3_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, A, LDA, JPVT, TAU, WORK, LWORK, RWORK, INFO)
functype = ctypes.CFUNCTYPE(
None,
*ctype_args,
_ptr_int, # INFO
else:
geqp3_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
nb_i32p, # JPVT
float_pointer, # TAU
float_pointer, # WORK
nb_i32p, # LWORK
nb_i32p, # INFO
)
)
return functype(lapack_ptr)
@numba_basic.numba_njit
def geqp3(M, N, A, LDA, JPVT, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_geqp3_pointer,
func_type_ref=geqp3_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, A, LDA, JPVT, TAU, WORK, LWORK, INFO)
return geqp3
@classmethod
def numba_xorgqr(cls, dtype):
def numba_xorgqr(cls, dtype) -> CPUDispatcher:
"""
Generate the orthogonal matrix Q from a QR factorization (real types).
Used in QR decomposition to form Q.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
_ptr_int, # K
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}orgqr"
@numba_basic.numba_njit
def get_orgqr_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "orgqr")
return ptr
orgqr_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
nb_i32p, # K
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
nb_i32p, # LWORK
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def orgqr(M, N, K, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_orgqr_pointer,
func_type_ref=orgqr_function_type,
unique_func_name_lit=unique_func_name,
)
return functype(lapack_ptr)
fn(M, N, K, A, LDA, TAU, WORK, LWORK, INFO)
return orgqr
@classmethod
def numba_xungqr(cls, dtype):
def numba_xungqr(cls, dtype) -> CPUDispatcher:
"""
Generate the unitary matrix Q from a QR factorization (complex types).
Used in QR decomposition to form Q for complex types.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
_ptr_int, # K
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}ungqr"
@numba_basic.numba_njit
def get_ungqr_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "ungqr")
return ptr
ungqr_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
nb_i32p, # K
float_pointer, # A
_ptr_int, # LDA
nb_i32p, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
nb_i32p, # LWORK
nb_i32p, # INFO
)
return functype(lapack_ptr)
)
@numba_basic.numba_njit
def ungqr(M, N, K, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_ungqr_pointer,
func_type_ref=ungqr_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, K, A, LDA, TAU, WORK, LWORK, INFO)
return ungqr
......@@ -31,7 +31,6 @@ def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
def xgeqrf_impl(A, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl(A, overwrite_a, lwork):
......@@ -57,10 +56,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.T.view(w_type).T.ctypes,
A_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -82,7 +81,6 @@ def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
def xgeqp3_impl(A, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype)
def impl(A, overwrite_a, lwork):
......@@ -109,11 +107,11 @@ def xgeqp3_impl(A, overwrite_a, lwork):
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.T.view(w_type).T.ctypes,
A_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -135,7 +133,6 @@ def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
def xorgqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
w_type = _get_underlying_float(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype)
def impl(A, tau, overwrite_a, lwork):
......@@ -162,10 +159,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.T.view(w_type).T.ctypes,
A_copy.ctypes,
LDA,
tau.view(w_type).ctypes,
WORK.view(w_type).ctypes,
tau.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -188,7 +185,6 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
ungqr = _LAPACK().numba_xungqr(dtype)
def impl(A, tau, overwrite_a, lwork):
......@@ -214,10 +210,10 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.T.view(w_type).T.ctypes,
A_copy.ctypes,
LDA,
tau.view(w_type).ctypes,
WORK.view(w_type).ctypes,
tau.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
......@@ -426,11 +422,11 @@ def qr_full_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
......@@ -439,11 +435,11 @@ def qr_full_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -458,11 +454,11 @@ def qr_full_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
......@@ -471,11 +467,11 @@ def qr_full_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......@@ -501,10 +497,10 @@ def qr_full_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes,
Q_in.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
TAU.ctypes,
WORKQ.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -519,10 +515,10 @@ def qr_full_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes,
Q_in.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
TAU.ctypes,
WORKQ.ctypes,
val_to_int_ptr(lwork_q),
INFOQ,
)
......@@ -538,7 +534,6 @@ def qr_full_no_pivot_impl(
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
orgqr = (
_LAPACK().numba_xorgqr(dtype)
......@@ -574,10 +569,10 @@ def qr_full_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -591,10 +586,10 @@ def qr_full_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......@@ -619,10 +614,10 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes,
Q_in.ctypes,
val_to_int_ptr(M),
TAU.view(w_type).ctypes,
WORKQ.view(w_type).ctypes,
TAU.ctypes,
WORKQ.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -637,10 +632,10 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), # M
val_to_int_ptr(Q_in.shape[1]), # N
val_to_int_ptr(K), # K
Q_in.T.view(w_type).T.ctypes, # A
Q_in.ctypes, # A
val_to_int_ptr(M), # LDA
TAU.view(w_type).ctypes, # TAU
WORKQ.view(w_type).ctypes, # WORK
TAU.ctypes, # TAU
WORKQ.ctypes, # WORK
val_to_int_ptr(lwork_q), # LWORK
INFOQ, # INFO
)
......@@ -656,7 +651,6 @@ def qr_r_pivot_impl(
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype)
def impl(
......@@ -687,11 +681,11 @@ def qr_r_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -705,11 +699,11 @@ def qr_r_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......@@ -732,7 +726,6 @@ def qr_r_no_pivot_impl(
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl(
......@@ -762,10 +755,10 @@ def qr_r_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -779,10 +772,10 @@ def qr_r_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......@@ -805,7 +798,6 @@ def qr_raw_no_pivot_impl(
ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl(
......@@ -835,10 +827,10 @@ def qr_raw_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -852,10 +844,10 @@ def qr_raw_no_pivot_impl(
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......@@ -914,11 +906,11 @@ def qr_raw_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1), # LWORK
RWORK.ctypes,
val_to_int_ptr(1), # INFO
......@@ -927,11 +919,11 @@ def qr_raw_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(-1),
val_to_int_ptr(1),
)
......@@ -946,11 +938,11 @@ def qr_raw_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
RWORK.ctypes,
INFO,
......@@ -959,11 +951,11 @@ def qr_raw_pivot_impl(
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes,
x_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.view(w_type).ctypes,
WORK.view(w_type).ctypes,
TAU.ctypes,
WORK.ctypes,
val_to_int_ptr(lwork_val),
INFO,
)
......
......@@ -6,7 +6,6 @@ from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
numba_funcify,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
......@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import (
)
@numba_funcify.register(Cholesky)
@register_funcify_default_op_cache_key(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
"""
Overload scipy.linalg.cholesky with a numba function.
......@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return res
return cholesky
cache_key = 1
return cholesky, cache_key
@register_funcify_default_op_cache_key(PivotToPermutations)
......@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs):
return numba_pivot_to_permutation, cache_key
@numba_funcify.register(LU)
@register_funcify_default_op_cache_key(LU)
def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
......@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs):
return res
return lu
cache_key = 1
return lu, cache_key
@numba_funcify.register(LUFactor)
@register_funcify_default_op_cache_key(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
......@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return LU, piv
return lu_factor
cache_key = 1
return lu_factor, cache_key
@register_funcify_default_op_cache_key(BlockDiagonal)
......@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return block_diag
@numba_funcify.register(Solve)
@register_funcify_default_op_cache_key(Solve)
def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
......@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs):
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res
return solve
cache_key = 1
return solve, cache_key
@numba_funcify.register(SolveTriangular)
@register_funcify_default_op_cache_key(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower
unit_diagonal = op.unit_diagonal
......@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
return res
return solve_triangular
cache_key = 1
return solve_triangular, cache_key
@numba_funcify.register(CholeskySolve)
@register_funcify_default_op_cache_key(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower
overwrite_b = op.overwrite_b
......@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite=check_finite,
)
return cho_solve
cache_key = 1
return cho_solve, cache_key
@numba_funcify.register(QR)
@register_funcify_default_op_cache_key(QR)
def numba_funcify_QR(op, node, **kwargs):
mode = op.mode
check_finite = op.check_finite
......@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs):
f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
)
return qr
cache_key = 1
return qr, cache_key
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论