Unverified 提交 5137ed3e authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Allow caching of cython functions in numba backend (#1807)

* Add machinery to enable numba caching of function pointers
* Numba cache for Cholesky
* Numba cache for SolveTriangular
* Numba cache for CholeskySolve
* Numba cache for solve helpers
* Numba cache for GECON
* Numba cache for lu_factor
* Numba cache for Solve when assume_a="gen"
* Numba cache for Solve when assume_a="sym"
* Numba cache for Solve when assume_a="pos"
* Numba cache for Solve when assume_a='tri'
* Numba cache for QR
* Clean up obsolete code
* Feedback
* More feedback
* Rename `cache_key_lit` -> `unique_func_name_lit`
上级 0b439c0f
...@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile ...@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
import numba
from llvmlite import ir
from numba.core import cgutils
from numba.core.caching import CacheImpl, _CacheLocator from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -127,3 +130,49 @@ def compile_numba_function_src( ...@@ -127,3 +130,49 @@ def compile_numba_function_src(
CACHED_SRC_FUNCTIONS[res] = cache_key CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore return res # type: ignore
@numba.extending.intrinsic(prefer_literal=True)
def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref, unique_func_name_lit):
    """
    Build a callable from a function pointer that is resolved lazily and then cached.

    When one of our Numba-dispatched functions depends on a pointer to a compiled
    function (e.g. when we call cython_lapack routines), numba will refuse to cache
    the function, because the pointer may change between runs. This intrinsic lets
    us cache the pointer ourselves: it is kept in a module-level global variable
    whose name is derived from the literal `unique_func_name_lit`. If the global
    still holds null when the generated code runs, `get_ptr_func` is called to
    obtain the pointer, which is written to the global and used; otherwise the
    value already stored in the global is used.

    Parameters
    ----------
    typingctx
        Typing context supplied by numba; used to resolve the zero-argument
        signature of `get_ptr_func`.
    get_ptr_func
        Numba type of a function that takes no arguments and returns the address
        as an integer (the result is converted with `inttoptr`).
    func_type_ref
        Type reference whose `instance_type` is the function type of the value
        this intrinsic returns.
    unique_func_name_lit
        Literal whose `literal_value` is appended to ``_ptr_cache_`` to name the
        global variable that stores the pointer.

    Returns
    -------
    tuple
        ``(signature, codegen)`` as expected from a numba intrinsic.
    """
    target_type = func_type_ref.instance_type
    slot_suffix = unique_func_name_lit.literal_value

    def codegen(context, builder, signature, args):
        byte_ptr_type = ir.PointerType(ir.IntType(8))
        null_ptr = byte_ptr_type(None)
        slot_align = 64

        # Private, null-initialised global that holds the cached address.
        slot = cgutils.add_global_variable(
            builder.module, byte_ptr_type, f"_ptr_cache_{slot_suffix}"
        )
        slot.align = slot_align
        slot.linkage = "private"
        slot.initializer = null_ptr

        # Start from whatever is currently stored; the local stack slot lets the
        # slow path below overwrite the value that is finally used.
        cached = builder.load_atomic(slot, "acquire", slot_align)
        address_slot = cgutils.alloca_once_value(builder, cached)

        is_unset = builder.icmp_signed("==", cached, null_ptr)
        with builder.if_then(is_unset, likely=False):
            # Slow path: ask `get_ptr_func` for the address and publish it.
            getter_sig = typingctx.resolve_function_type(get_ptr_func, [], {})
            getter = context.get_function(get_ptr_func, getter_sig)
            fresh = builder.inttoptr(getter(builder, []), byte_ptr_type)
            builder.store_atomic(fresh, slot, "release", slot_align)
            builder.store(fresh, address_slot)

        # Wrap the raw address in the struct numba uses for `target_type`.
        func_struct = cgutils.create_struct_proxy(target_type)(context, builder)
        func_struct.c_addr = builder.load(address_slot)
        return func_struct._getvalue()

    return target_type(get_ptr_func, func_type_ref, unique_func_name_lit), codegen
import ctypes import numba
import numpy as np import numpy as np
from numba.core import cgutils, types from numba.core import cgutils, types
from numba.core.extending import get_cython_function_address, intrinsic from numba.core.extending import get_cython_function_address, intrinsic
from numba.core.registry import CPUDispatcher
from numba.core.types import Complex from numba.core.types import Complex
from numba.np.linalg import ensure_lapack, get_blas_kind from numba.np.linalg import ensure_lapack, get_blas_kind
from pytensor.link.numba.cache import _call_cached_ptr
from pytensor.link.numba.dispatch import basic as numba_basic
nb_i32 = types.int32
nb_i32p = types.CPointer(nb_i32)
_PTR = ctypes.POINTER nb_f32 = types.float32
nb_f32p = types.CPointer(nb_f32)
_dbl = ctypes.c_double nb_f64 = types.float64
_float = ctypes.c_float nb_f64p = types.CPointer(nb_f64)
_char = ctypes.c_char
_int = ctypes.c_int
_ptr_float = _PTR(_float) nb_c64 = types.complex64
_ptr_dbl = _PTR(_dbl) nb_c64p = types.CPointer(nb_c64)
_ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)
nb_c128 = types.complex128
nb_c128p = types.CPointer(nb_c128)
def _get_lapack_ptr_and_ptr_type(dtype, name):
def get_lapack_ptr(dtype, name):
d = get_blas_kind(dtype) d = get_blas_kind(dtype)
func_name = f"{d}{name}" func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
return lapack_ptr
return lapack_ptr, float_pointer
def _get_underlying_float(dtype): def _get_underlying_float(dtype):
...@@ -40,19 +44,18 @@ def _get_underlying_float(dtype): ...@@ -40,19 +44,18 @@ def _get_underlying_float(dtype):
return np.dtype(out_type) return np.dtype(out_type)
def _get_float_pointer_for_dtype(blas_dtype): def _get_nb_float_from_dtype(blas_dtype, return_pointer=True):
if blas_dtype in ["s", "c"]: match blas_dtype:
return _ptr_float case "s":
elif blas_dtype in ["d", "z"]: return nb_f32p if return_pointer else nb_f32
return _ptr_dbl case "d":
return nb_f64p if return_pointer else nb_f64
case "c":
def _get_output_ctype(dtype): return nb_c64p if return_pointer else nb_c64
s_dtype = str(dtype) case "z":
if s_dtype in ["float32", "complex64"]: return nb_c128p if return_pointer else nb_c128
return _float case _:
elif s_dtype in ["float64", "complex128"]: raise ValueError(f"Unsupported BLAS dtype: {blas_dtype}")
return _dbl
@intrinsic @intrinsic
...@@ -136,422 +139,815 @@ class _LAPACK: ...@@ -136,422 +139,815 @@ class _LAPACK:
ensure_lapack() ensure_lapack()
@classmethod @classmethod
def numba_xtrtrs(cls, dtype): def numba_xtrtrs(cls, dtype) -> CPUDispatcher:
""" """
Solve a triangular system of equations of the form A @ X = B or A.T @ X = B. Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.
Called by scipy.linalg.solve_triangular Called by scipy.linalg.solve_triangular
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")
functype = ctypes.CFUNCTYPE( kind = get_blas_kind(dtype)
None, float_ptr = _get_nb_float_from_dtype(kind)
_ptr_int, # UPLO unique_func_name = f"scipy.lapack.{kind}trtrs"
_ptr_int, # TRANS
_ptr_int, # DIAG @numba_basic.numba_njit
_ptr_int, # N def get_trtrs_pointer():
_ptr_int, # NRHS with numba.objmode(ptr=types.intp):
float_pointer, # A ptr = get_lapack_ptr(dtype, "trtrs")
_ptr_int, # LDA return ptr
float_pointer, # B
_ptr_int, # LDB trtrs_function_type = types.FunctionType(
_ptr_int, # INFO types.void(
nb_i32p, # UPLO
nb_i32p, # TRANS
nb_i32p, # DIAG
nb_i32p, # N
nb_i32p, # NRHS
float_ptr, # A
nb_i32p, # LDA
float_ptr, # B
nb_i32p, # LDB
nb_i32p, # INFO
)
) )
return functype(lapack_ptr) @numba_basic.numba_njit
def trtrs(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_trtrs_pointer,
func_type_ref=trtrs_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO)
return trtrs
@classmethod @classmethod
def numba_xpotrf(cls, dtype): def numba_xpotrf(cls, dtype) -> CPUDispatcher:
""" """
Compute the Cholesky factorization of a real symmetric positive definite matrix. Compute the Cholesky factorization of a real symmetric positive definite matrix.
Called by scipy.linalg.cholesky Called by scipy.linalg.cholesky
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE( kind = get_blas_kind(dtype)
None, float_ptr = _get_nb_float_from_dtype(kind)
_ptr_int, # UPLO, unique_func_name = f"scipy.lapack.{kind}potrf"
_ptr_int, # N
float_pointer, # A @numba_basic.numba_njit
_ptr_int, # LDA def get_potrf_pointer():
_ptr_int, # INFO with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "potrf")
return ptr
potrf_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def potrf(UPLO, N, A, LDA, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_potrf_pointer,
func_type_ref=potrf_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) fn(UPLO, N, A, LDA, INFO)
return potrf
@classmethod @classmethod
def numba_xpotrs(cls, dtype): def numba_xpotrs(cls, dtype) -> CPUDispatcher:
""" """
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by numba_potrf. factorization computed by numba_potrf.
Called by scipy.linalg.cho_solve Called by scipy.linalg.cho_solve
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}potrs"
_ptr_int, # UPLO
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # NRHS def get_potrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "potrs")
return ptr
potrs_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # B float_pointer, # B
_ptr_int, # LDB nb_i32p, # LDB
_ptr_int, # INFO nb_i32p, # INFO
) )
return functype(lapack_ptr) )
@numba_basic.numba_njit
def potrs(UPLO, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_potrs_pointer,
func_type_ref=potrs_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, B, LDB, INFO)
return potrs
@classmethod @classmethod
def numba_xlange(cls, dtype): def numba_xlange(cls, dtype) -> CPUDispatcher:
""" """
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A. a general M-by-N matrix A.
Called by scipy.linalg.solve Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange") kind = get_blas_kind(dtype)
output_ctype = _get_output_ctype(dtype) float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)
output_ctype, # Output unique_func_name = f"scipy.lapack.{kind}lange"
_ptr_int, # NORM
_ptr_int, # M @numba_basic.numba_njit
_ptr_int, # N def get_lange_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lange")
return ptr
lange_function_type = types.FunctionType(
float_type(
nb_i32p, # NORM
nb_i32p, # M
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # WORK float_pointer, # WORK
) )
return functype(lapack_ptr) )
@numba_basic.numba_njit
def lange(NORM, M, N, A, LDA, WORK):
fn = _call_cached_ptr(
get_ptr_func=get_lange_pointer,
func_type_ref=lange_function_type,
unique_func_name_lit=unique_func_name,
)
return fn(NORM, M, N, A, LDA, WORK)
return lange
@classmethod @classmethod
def numba_xlamch(cls, dtype): def numba_xlamch(cls, dtype) -> CPUDispatcher:
""" """
Determine machine precision for floating point arithmetic. Determine machine precision for floating point arithmetic.
""" """
kind = get_blas_kind(dtype)
float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
unique_func_name = f"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def get_lamch_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "lamch")
return ptr
lamch_function_type = types.FunctionType(
float_type( # Return type
nb_i32p, # CMACH
)
)
lapack_ptr, _float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch") @numba_basic.numba_njit
output_dtype = _get_output_ctype(dtype) def lamch(CMACH):
functype = ctypes.CFUNCTYPE( fn = _call_cached_ptr(
output_dtype, # Output get_ptr_func=get_lamch_pointer,
_ptr_int, # CMACH func_type_ref=lamch_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) res = fn(CMACH)
return res
return lamch
@classmethod @classmethod
def numba_xgecon(cls, dtype): def numba_xgecon(cls, dtype) -> CPUDispatcher:
""" """
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf. Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen" Called by scipy.linalg.solve when assume_a == "gen"
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}gecon"
_ptr_int, # NORM
_ptr_int, # N @numba_basic.numba_njit
def get_gecon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gecon")
return ptr
gecon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # ANORM float_pointer, # ANORM
float_pointer, # RCOND float_pointer, # RCOND
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # IWORK nb_i32p, # IWORK
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gecon_pointer,
func_type_ref=gecon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return gecon
@classmethod @classmethod
def numba_xgetrf(cls, dtype): def numba_xgetrf(cls, dtype) -> CPUDispatcher:
""" """
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges. Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
Called by scipy.linalg.lu_factor Called by scipy.linalg.lu_factor
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}getrf"
_ptr_int, # M
_ptr_int, # N @numba_basic.numba_njit
def get_getrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "getrf")
return ptr
getrf_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
_ptr_int, # IPIV nb_i32p, # IPIV
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def getrf(M, N, A, LDA, IPIV, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_getrf_pointer,
func_type_ref=getrf_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, A, LDA, IPIV, INFO)
return getrf
@classmethod @classmethod
def numba_xgetrs(cls, dtype): def numba_xgetrs(cls, dtype) -> CPUDispatcher:
""" """
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
factorization computed by GETRF. factorization computed by GETRF.
Called by scipy.linalg.lu_solve Called by scipy.linalg.lu_solve
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}getrs"
_ptr_int, # TRANS
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # NRHS def get_getrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "getrs")
return ptr
getrs_function_type = types.FunctionType(
types.void(
nb_i32p, # TRANS
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
_ptr_int, # IPIV nb_i32p, # IPIV
float_pointer, # B float_pointer, # B
_ptr_int, # LDB nb_i32p, # LDB
_ptr_int, # INFO nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def getrs(TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_getrs_pointer,
func_type_ref=getrs_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) fn(TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO)
return getrs
@classmethod @classmethod
def numba_xsysv(cls, dtype): def numba_xsysv(cls, dtype) -> CPUDispatcher:
""" """
Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method, Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
factorizing A into LDL^T or UDU^T form, depending on the value of UPLO factorizing A into LDL^T or UDU^T form, depending on the value of UPLO
Called by scipy.linalg.solve when assume_a == "sym" Called by scipy.linalg.solve when assume_a == "sym"
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}sysv"
_ptr_int, # UPLO
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # NRHS def get_sysv_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sysv")
return ptr
sysv_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
_ptr_int, # IPIV nb_i32p, # IPIV
float_pointer, # B float_pointer, # B
_ptr_int, # LDB nb_i32p, # LDB
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK nb_i32p, # LWORK
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def sysv(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sysv_pointer,
func_type_ref=sysv_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO)
return sysv
@classmethod @classmethod
def numba_xsycon(cls, dtype): def numba_xsycon(cls, dtype) -> CPUDispatcher:
""" """
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF. computed by xSYTRF.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon") kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def get_sycon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "sycon")
return ptr
functype = ctypes.CFUNCTYPE( sycon_function_type = types.FunctionType(
None, types.void(
_ptr_int, # UPLO nb_i32p, # UPLO
_ptr_int, # N nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
_ptr_int, # IPIV nb_i32p, # IPIV
float_pointer, # ANORM float_pointer, # ANORM
float_pointer, # RCOND float_pointer, # RCOND
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # IWORK nb_i32p, # IWORK
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_sycon_pointer,
func_type_ref=sycon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return sycon
@classmethod @classmethod
def numba_xpocon(cls, dtype): def numba_xpocon(cls, dtype) -> CPUDispatcher:
""" """
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf. computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos" Called by scipy.linalg.solve when assume_a == "pos"
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}pocon"
_ptr_int, # UPLO
_ptr_int, # N @numba_basic.numba_njit
def get_pocon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "pocon")
return ptr
pocon_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # ANORM float_pointer, # ANORM
float_pointer, # RCOND float_pointer, # RCOND
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # IWORK nb_i32p, # IWORK
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_pocon_pointer,
func_type_ref=pocon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
return pocon
@classmethod @classmethod
def numba_xposv(cls, dtype): def numba_xposv(cls, dtype) -> CPUDispatcher:
""" """
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by potrf. factorization computed by potrf.
""" """
kind = get_blas_kind(dtype)
float_pointer = _get_nb_float_from_dtype(kind)
unique_func_name = f"scipy.lapack.{kind}posv"
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv") @numba_basic.numba_njit
functype = ctypes.CFUNCTYPE( def get_posv_pointer():
None, with numba.objmode(ptr=types.intp):
_ptr_int, # UPLO ptr = get_lapack_ptr(dtype, "posv")
_ptr_int, # N return ptr
_ptr_int, # NRHS
posv_function_type = types.FunctionType(
types.void(
nb_i32p, # UPLO
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # B float_pointer, # B
_ptr_int, # LDB nb_i32p, # LDB
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def posv(UPLO, N, NRHS, A, LDA, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_posv_pointer,
func_type_ref=posv_function_type,
unique_func_name_lit=unique_func_name,
)
fn(UPLO, N, NRHS, A, LDA, B, LDB, INFO)
return posv
@classmethod @classmethod
def numba_xgttrf(cls, dtype): def numba_xgttrf(cls, dtype) -> CPUDispatcher:
""" """
Compute the LU factorization of a tridiagonal matrix A using row interchanges. Compute the LU factorization of a tridiagonal matrix A using row interchanges.
Called by scipy.linalg.lu_factor Called by scipy.linalg.solve when assume_a == "tri"
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrf") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}gttrf"
_ptr_int, # N
@numba_basic.numba_njit
def get_gttrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gttrf")
return ptr
gttrf_function_type = types.FunctionType(
types.void(
nb_i32p, # N
float_pointer, # DL float_pointer, # DL
float_pointer, # D float_pointer, # D
float_pointer, # DU float_pointer, # DU
float_pointer, # DU2 float_pointer, # DU2
_ptr_int, # IPIV nb_i32p, # IPIV
_ptr_int, # INFO nb_i32p, # INFO
)
) )
return functype(lapack_ptr)
@numba_basic.numba_njit
def gttrf(N, DL, D, DU, DU2, IPIV, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gttrf_pointer,
func_type_ref=gttrf_function_type,
unique_func_name_lit=unique_func_name,
)
fn(N, DL, D, DU, DU2, IPIV, INFO)
return gttrf
@classmethod @classmethod
def numba_xgttrs(cls, dtype): def numba_xgttrs(cls, dtype) -> CPUDispatcher:
""" """
Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf. Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
Called by scipy.linalg.lu_solve Called by scipy.linalg.solve, when assume_a == "tri"
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrs") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}gttrs"
_ptr_int, # TRANS
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # NRHS def get_gttrs_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gttrs")
return ptr
gttrs_function_type = types.FunctionType(
types.void(
nb_i32p, # TRANS
nb_i32p, # N
nb_i32p, # NRHS
float_pointer, # DL float_pointer, # DL
float_pointer, # D float_pointer, # D
float_pointer, # DU float_pointer, # DU
float_pointer, # DU2 float_pointer, # DU2
_ptr_int, # IPIV nb_i32p, # IPIV
float_pointer, # B float_pointer, # B
_ptr_int, # LDB nb_i32p, # LDB
_ptr_int, # INFO nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def gttrs(TRANS, N, NRHS, DL, D, DU, DU2, IPIV, B, LDB, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gttrs_pointer,
func_type_ref=gttrs_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) fn(TRANS, N, NRHS, DL, D, DU, DU2, IPIV, B, LDB, INFO)
return gttrs
@classmethod @classmethod
def numba_xgtcon(cls, dtype): def numba_xgtcon(cls, dtype) -> CPUDispatcher:
""" """
Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf. Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gtcon") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}gtcon"
_ptr_int, # NORM
_ptr_int, # N @numba_basic.numba_njit
def get_gtcon_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gtcon")
return ptr
gtcon_function_type = types.FunctionType(
types.void(
nb_i32p, # NORM
nb_i32p, # N
float_pointer, # DL float_pointer, # DL
float_pointer, # D float_pointer, # D
float_pointer, # DU float_pointer, # DU
float_pointer, # DU2 float_pointer, # DU2
_ptr_int, # IPIV nb_i32p, # IPIV
float_pointer, # ANORM float_pointer, # ANORM
float_pointer, # RCOND float_pointer, # RCOND
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # IWORK nb_i32p, # IWORK
_ptr_int, # INFO nb_i32p, # INFO
) )
return functype(lapack_ptr) )
@numba_basic.numba_njit
def gtcon(NORM, N, DL, D, DU, DU2, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_gtcon_pointer,
func_type_ref=gtcon_function_type,
unique_func_name_lit=unique_func_name,
)
fn(NORM, N, DL, D, DU, DU2, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
return gtcon
@classmethod @classmethod
def numba_xgeqrf(cls, dtype): def numba_xgeqrf(cls, dtype) -> CPUDispatcher:
""" """
Compute the QR factorization of a general M-by-N matrix A. Compute the QR factorization of a general M-by-N matrix A.
Used in QR decomposition (no pivoting). Used in QR decomposition (no pivoting).
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}geqrf"
_ptr_int, # M
_ptr_int, # N @numba_basic.numba_njit
def get_geqrf_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "geqrf")
return ptr
geqrf_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK nb_i32p, # LWORK
_ptr_int, # INFO nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def geqrf(M, N, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_geqrf_pointer,
func_type_ref=geqrf_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) fn(M, N, A, LDA, TAU, WORK, LWORK, INFO)
return geqrf
@classmethod @classmethod
def numba_xgeqp3(cls, dtype): def numba_xgeqp3(cls, dtype) -> CPUDispatcher:
""" """
Compute the QR factorization with column pivoting of a general M-by-N matrix A. Compute the QR factorization with column pivoting of a general M-by-N matrix A.
Used in QR decomposition with pivoting. Used in QR decomposition with pivoting.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3") kind = get_blas_kind(dtype)
ctype_args = ( float_pointer = _get_nb_float_from_dtype(kind)
_ptr_int, # M unique_func_name = f"scipy.lapack.{kind}geqp3"
_ptr_int, # N
@numba_basic.numba_njit
def get_geqp3_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "geqp3")
return ptr
if isinstance(dtype, Complex):
real_pointer = nb_f64p if dtype is nb_c128 else nb_f32p
geqp3_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
_ptr_int, # JPVT nb_i32p, # JPVT
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK nb_i32p, # LWORK
real_pointer, # RWORK
nb_i32p, # INFO
)
) )
if isinstance(dtype, Complex): @numba_basic.numba_njit
ctype_args = ( def geqp3(M, N, A, LDA, JPVT, TAU, WORK, LWORK, RWORK, INFO):
*ctype_args, fn = _call_cached_ptr(
float_pointer, # RWORK) get_ptr_func=get_geqp3_pointer,
func_type_ref=geqp3_function_type,
unique_func_name_lit=unique_func_name,
) )
fn(M, N, A, LDA, JPVT, TAU, WORK, LWORK, RWORK, INFO)
functype = ctypes.CFUNCTYPE( else:
None, geqp3_function_type = types.FunctionType(
*ctype_args, types.void(
_ptr_int, # INFO nb_i32p, # M
nb_i32p, # N
float_pointer, # A
nb_i32p, # LDA
nb_i32p, # JPVT
float_pointer, # TAU
float_pointer, # WORK
nb_i32p, # LWORK
nb_i32p, # INFO
)
) )
return functype(lapack_ptr) @numba_basic.numba_njit
def geqp3(M, N, A, LDA, JPVT, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_geqp3_pointer,
func_type_ref=geqp3_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, A, LDA, JPVT, TAU, WORK, LWORK, INFO)
return geqp3
@classmethod @classmethod
def numba_xorgqr(cls, dtype): def numba_xorgqr(cls, dtype) -> CPUDispatcher:
""" """
Generate the orthogonal matrix Q from a QR factorization (real types). Generate the orthogonal matrix Q from a QR factorization (real types).
Used in QR decomposition to form Q. Used in QR decomposition to form Q.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}orgqr"
_ptr_int, # M
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # K def get_orgqr_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "orgqr")
return ptr
orgqr_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
nb_i32p, # K
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK nb_i32p, # LWORK
_ptr_int, # INFO nb_i32p, # INFO
)
)
@numba_basic.numba_njit
def orgqr(M, N, K, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_orgqr_pointer,
func_type_ref=orgqr_function_type,
unique_func_name_lit=unique_func_name,
) )
return functype(lapack_ptr) fn(M, N, K, A, LDA, TAU, WORK, LWORK, INFO)
return orgqr
@classmethod @classmethod
def numba_xungqr(cls, dtype): def numba_xungqr(cls, dtype) -> CPUDispatcher:
""" """
Generate the unitary matrix Q from a QR factorization (complex types). Generate the unitary matrix Q from a QR factorization (complex types).
Used in QR decomposition to form Q for complex types. Used in QR decomposition to form Q for complex types.
""" """
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr") kind = get_blas_kind(dtype)
functype = ctypes.CFUNCTYPE( float_pointer = _get_nb_float_from_dtype(kind)
None, unique_func_name = f"scipy.lapack.{kind}ungqr"
_ptr_int, # M
_ptr_int, # N @numba_basic.numba_njit
_ptr_int, # K def get_ungqr_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "ungqr")
return ptr
ungqr_function_type = types.FunctionType(
types.void(
nb_i32p, # M
nb_i32p, # N
nb_i32p, # K
float_pointer, # A float_pointer, # A
_ptr_int, # LDA nb_i32p, # LDA
float_pointer, # TAU float_pointer, # TAU
float_pointer, # WORK float_pointer, # WORK
_ptr_int, # LWORK nb_i32p, # LWORK
_ptr_int, # INFO nb_i32p, # INFO
) )
return functype(lapack_ptr) )
@numba_basic.numba_njit
def ungqr(M, N, K, A, LDA, TAU, WORK, LWORK, INFO):
fn = _call_cached_ptr(
get_ptr_func=get_ungqr_pointer,
func_type_ref=ungqr_function_type,
unique_func_name_lit=unique_func_name,
)
fn(M, N, K, A, LDA, TAU, WORK, LWORK, INFO)
return ungqr
...@@ -31,7 +31,6 @@ def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): ...@@ -31,7 +31,6 @@ def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
def xgeqrf_impl(A, overwrite_a, lwork): def xgeqrf_impl(A, overwrite_a, lwork):
ensure_lapack() ensure_lapack()
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl(A, overwrite_a, lwork): def impl(A, overwrite_a, lwork):
...@@ -57,10 +56,10 @@ def xgeqrf_impl(A, overwrite_a, lwork): ...@@ -57,10 +56,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
A_copy.T.view(w_type).T.ctypes, A_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -82,7 +81,6 @@ def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int): ...@@ -82,7 +81,6 @@ def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
def xgeqp3_impl(A, overwrite_a, lwork): def xgeqp3_impl(A, overwrite_a, lwork):
ensure_lapack() ensure_lapack()
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype) geqp3 = _LAPACK().numba_xgeqp3(dtype)
def impl(A, overwrite_a, lwork): def impl(A, overwrite_a, lwork):
...@@ -109,11 +107,11 @@ def xgeqp3_impl(A, overwrite_a, lwork): ...@@ -109,11 +107,11 @@ def xgeqp3_impl(A, overwrite_a, lwork):
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
A_copy.T.view(w_type).T.ctypes, A_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -135,7 +133,6 @@ def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): ...@@ -135,7 +133,6 @@ def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
def xorgqr_impl(A, tau, overwrite_a, lwork): def xorgqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack() ensure_lapack()
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
orgqr = _LAPACK().numba_xorgqr(dtype) orgqr = _LAPACK().numba_xorgqr(dtype)
def impl(A, tau, overwrite_a, lwork): def impl(A, tau, overwrite_a, lwork):
...@@ -162,10 +159,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork): ...@@ -162,10 +159,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
val_to_int_ptr(K), val_to_int_ptr(K),
A_copy.T.view(w_type).T.ctypes, A_copy.ctypes,
LDA, LDA,
tau.view(w_type).ctypes, tau.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -188,7 +185,6 @@ def xungqr_impl(A, tau, overwrite_a, lwork): ...@@ -188,7 +185,6 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype)
ungqr = _LAPACK().numba_xungqr(dtype) ungqr = _LAPACK().numba_xungqr(dtype)
def impl(A, tau, overwrite_a, lwork): def impl(A, tau, overwrite_a, lwork):
...@@ -214,10 +210,10 @@ def xungqr_impl(A, tau, overwrite_a, lwork): ...@@ -214,10 +210,10 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
val_to_int_ptr(K), val_to_int_ptr(K),
A_copy.T.view(w_type).T.ctypes, A_copy.ctypes,
LDA, LDA,
tau.view(w_type).ctypes, tau.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
LWORK, LWORK,
INFO, INFO,
) )
...@@ -426,11 +422,11 @@ def qr_full_pivot_impl( ...@@ -426,11 +422,11 @@ def qr_full_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), # LWORK val_to_int_ptr(-1), # LWORK
RWORK.ctypes, RWORK.ctypes,
val_to_int_ptr(1), # INFO val_to_int_ptr(1), # INFO
...@@ -439,11 +435,11 @@ def qr_full_pivot_impl( ...@@ -439,11 +435,11 @@ def qr_full_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -458,11 +454,11 @@ def qr_full_pivot_impl( ...@@ -458,11 +454,11 @@ def qr_full_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
RWORK.ctypes, RWORK.ctypes,
INFO, INFO,
...@@ -471,11 +467,11 @@ def qr_full_pivot_impl( ...@@ -471,11 +467,11 @@ def qr_full_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
...@@ -501,10 +497,10 @@ def qr_full_pivot_impl( ...@@ -501,10 +497,10 @@ def qr_full_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes, Q_in.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.ctypes,
WORKQ.view(w_type).ctypes, WORKQ.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -519,10 +515,10 @@ def qr_full_pivot_impl( ...@@ -519,10 +515,10 @@ def qr_full_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes, Q_in.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.ctypes,
WORKQ.view(w_type).ctypes, WORKQ.ctypes,
val_to_int_ptr(lwork_q), val_to_int_ptr(lwork_q),
INFOQ, INFOQ,
) )
...@@ -538,7 +534,6 @@ def qr_full_no_pivot_impl( ...@@ -538,7 +534,6 @@ def qr_full_no_pivot_impl(
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
orgqr = ( orgqr = (
_LAPACK().numba_xorgqr(dtype) _LAPACK().numba_xorgqr(dtype)
...@@ -574,10 +569,10 @@ def qr_full_no_pivot_impl( ...@@ -574,10 +569,10 @@ def qr_full_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -591,10 +586,10 @@ def qr_full_no_pivot_impl( ...@@ -591,10 +586,10 @@ def qr_full_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
...@@ -619,10 +614,10 @@ def qr_full_no_pivot_impl( ...@@ -619,10 +614,10 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(Q_in.shape[1]), val_to_int_ptr(Q_in.shape[1]),
val_to_int_ptr(K), val_to_int_ptr(K),
Q_in.T.view(w_type).T.ctypes, Q_in.ctypes,
val_to_int_ptr(M), val_to_int_ptr(M),
TAU.view(w_type).ctypes, TAU.ctypes,
WORKQ.view(w_type).ctypes, WORKQ.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -637,10 +632,10 @@ def qr_full_no_pivot_impl( ...@@ -637,10 +632,10 @@ def qr_full_no_pivot_impl(
val_to_int_ptr(M), # M val_to_int_ptr(M), # M
val_to_int_ptr(Q_in.shape[1]), # N val_to_int_ptr(Q_in.shape[1]), # N
val_to_int_ptr(K), # K val_to_int_ptr(K), # K
Q_in.T.view(w_type).T.ctypes, # A Q_in.ctypes, # A
val_to_int_ptr(M), # LDA val_to_int_ptr(M), # LDA
TAU.view(w_type).ctypes, # TAU TAU.ctypes, # TAU
WORKQ.view(w_type).ctypes, # WORK WORKQ.ctypes, # WORK
val_to_int_ptr(lwork_q), # LWORK val_to_int_ptr(lwork_q), # LWORK
INFOQ, # INFO INFOQ, # INFO
) )
...@@ -656,7 +651,6 @@ def qr_r_pivot_impl( ...@@ -656,7 +651,6 @@ def qr_r_pivot_impl(
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqp3 = _LAPACK().numba_xgeqp3(dtype) geqp3 = _LAPACK().numba_xgeqp3(dtype)
def impl( def impl(
...@@ -687,11 +681,11 @@ def qr_r_pivot_impl( ...@@ -687,11 +681,11 @@ def qr_r_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -705,11 +699,11 @@ def qr_r_pivot_impl( ...@@ -705,11 +699,11 @@ def qr_r_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
...@@ -732,7 +726,6 @@ def qr_r_no_pivot_impl( ...@@ -732,7 +726,6 @@ def qr_r_no_pivot_impl(
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl( def impl(
...@@ -762,10 +755,10 @@ def qr_r_no_pivot_impl( ...@@ -762,10 +755,10 @@ def qr_r_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -779,10 +772,10 @@ def qr_r_no_pivot_impl( ...@@ -779,10 +772,10 @@ def qr_r_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
...@@ -805,7 +798,6 @@ def qr_raw_no_pivot_impl( ...@@ -805,7 +798,6 @@ def qr_raw_no_pivot_impl(
ensure_lapack() ensure_lapack()
_check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr") _check_linalg_matrix(x, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = x.dtype dtype = x.dtype
w_type = _get_underlying_float(dtype)
geqrf = _LAPACK().numba_xgeqrf(dtype) geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl( def impl(
...@@ -835,10 +827,10 @@ def qr_raw_no_pivot_impl( ...@@ -835,10 +827,10 @@ def qr_raw_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -852,10 +844,10 @@ def qr_raw_no_pivot_impl( ...@@ -852,10 +844,10 @@ def qr_raw_no_pivot_impl(
geqrf( geqrf(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
...@@ -914,11 +906,11 @@ def qr_raw_pivot_impl( ...@@ -914,11 +906,11 @@ def qr_raw_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), # LWORK val_to_int_ptr(-1), # LWORK
RWORK.ctypes, RWORK.ctypes,
val_to_int_ptr(1), # INFO val_to_int_ptr(1), # INFO
...@@ -927,11 +919,11 @@ def qr_raw_pivot_impl( ...@@ -927,11 +919,11 @@ def qr_raw_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(-1), val_to_int_ptr(-1),
val_to_int_ptr(1), val_to_int_ptr(1),
) )
...@@ -946,11 +938,11 @@ def qr_raw_pivot_impl( ...@@ -946,11 +938,11 @@ def qr_raw_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
RWORK.ctypes, RWORK.ctypes,
INFO, INFO,
...@@ -959,11 +951,11 @@ def qr_raw_pivot_impl( ...@@ -959,11 +951,11 @@ def qr_raw_pivot_impl(
geqp3( geqp3(
val_to_int_ptr(M), val_to_int_ptr(M),
val_to_int_ptr(N), val_to_int_ptr(N),
x_copy.T.view(w_type).T.ctypes, x_copy.ctypes,
LDA, LDA,
JPVT.ctypes, JPVT.ctypes,
TAU.view(w_type).ctypes, TAU.ctypes,
WORK.view(w_type).ctypes, WORK.ctypes,
val_to_int_ptr(lwork_val), val_to_int_ptr(lwork_val),
INFO, INFO,
) )
......
...@@ -6,7 +6,6 @@ from pytensor import config ...@@ -6,7 +6,6 @@ from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl, generate_fallback_impl,
numba_funcify,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
...@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import ( ...@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import (
) )
@numba_funcify.register(Cholesky) @register_funcify_default_op_cache_key(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs): def numba_funcify_Cholesky(op, node, **kwargs):
""" """
Overload scipy.linalg.cholesky with a numba function. Overload scipy.linalg.cholesky with a numba function.
...@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return res return res
return cholesky cache_key = 1
return cholesky, cache_key
@register_funcify_default_op_cache_key(PivotToPermutations) @register_funcify_default_op_cache_key(PivotToPermutations)
...@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs):
return numba_pivot_to_permutation, cache_key return numba_pivot_to_permutation, cache_key
@numba_funcify.register(LU) @register_funcify_default_op_cache_key(LU)
def numba_funcify_LU(op, node, **kwargs): def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
...@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs):
return res return res
return lu cache_key = 1
return lu, cache_key
@numba_funcify.register(LUFactor) @register_funcify_default_op_cache_key(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs): def numba_funcify_LUFactor(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
...@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return LU, piv return LU, piv
return lu_factor cache_key = 1
return lu_factor, cache_key
@register_funcify_default_op_cache_key(BlockDiagonal) @register_funcify_default_op_cache_key(BlockDiagonal)
...@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return block_diag return block_diag
@numba_funcify.register(Solve) @register_funcify_default_op_cache_key(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
...@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs):
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res return res
return solve cache_key = 1
return solve, cache_key
@numba_funcify.register(SolveTriangular) @register_funcify_default_op_cache_key(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs): def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
...@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
return res return res
return solve_triangular cache_key = 1
return solve_triangular, cache_key
@numba_funcify.register(CholeskySolve) @register_funcify_default_op_cache_key(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs): def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower lower = op.lower
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
...@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite=check_finite, check_finite=check_finite,
) )
return cho_solve cache_key = 1
return cho_solve, cache_key
@numba_funcify.register(QR) @register_funcify_default_op_cache_key(QR)
def numba_funcify_QR(op, node, **kwargs): def numba_funcify_QR(op, node, **kwargs):
mode = op.mode mode = op.mode
check_finite = op.check_finite check_finite = op.check_finite
...@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs):
f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
) )
return qr cache_key = 1
return qr, cache_key
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论