Unverified commit 5137ed3e, authored by Jesse Grabowski, committed by GitHub

Allow caching of cython functions in numba backend (#1807)

* Add machinery to enable numba caching of function pointers * Numba cache for Cholesky * Numba cache for SolveTriangular * Numba cache for CholeskySolve * Numba cache for solve helpers * Numba cache for GECON * Numba cache for lu_factor * Numba cache for Solve when assume_a="gen" * Numba cache for Solve when assume_a="sym" * Numba cache for Solve when assume_a="pos" * Numba cache for Solve when assume_a='tri' * Numba cache for QR * Clean up obsolete code * Feedback * More feedback * Rename `cache_key_lit` -> `unique_func_name_lit`
Parent commit: 0b439c0f
......@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile
from typing import Any
from weakref import WeakKeyDictionary
import numba
from llvmlite import ir
from numba.core import cgutils
from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config
......@@ -127,3 +130,49 @@ def compile_numba_function_src(
CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore
@numba.extending.intrinsic(prefer_literal=True)
def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref, unique_func_name_lit):
    """
    Enable caching of function pointers returned by `get_ptr_func`.

    When one of our Numba-dispatched functions depends on a pointer to a
    compiled function (e.g. when we call cython_lapack routines), numba will
    refuse to cache the function, because the pointer may change between runs.

    This intrinsic allows us to cache the pointer ourselves, by storing it in a
    global variable keyed by a literal `unique_func_name_lit`. The first time
    the intrinsic is called, it will call `get_ptr_func` to get the pointer,
    store it in the global variable, and return it. Subsequent calls will load
    the pointer from the global variable.

    Parameters
    ----------
    typingctx
        Numba typing context (supplied by the intrinsic machinery).
    get_ptr_func
        Callable type; invoked with no arguments to produce the raw address
        (as an integer) of the compiled function.
    func_type_ref
        TypeRef wrapping the Numba function type to expose the pointer as.
    unique_func_name_lit
        String literal uniquely identifying this pointer's cache slot
        (`prefer_literal=True` ensures it arrives as a Literal type).
    """
    # Concrete function type to return; unwrap it from the TypeRef.
    func_type = func_type_ref.instance_type
    # Literal string that names the module-level global holding the pointer.
    cache_key = unique_func_name_lit.literal_value

    def codegen(context, builder, signature, args):
        # The cache slot is an i8* global, initialized to NULL.
        ptr_ty = ir.PointerType(ir.IntType(8))
        null = ptr_ty(None)
        align = 64

        mod = builder.module
        var = cgutils.add_global_variable(mod, ptr_ty, f"_ptr_cache_{cache_key}")
        var.align = align
        # Private linkage keeps the slot local to this LLVM module.
        var.linkage = "private"
        var.initializer = null

        # Fast path: atomically read the cached pointer (acquire pairs with
        # the release store below).
        var_val = builder.load_atomic(var, "acquire", align)
        result_ptr = cgutils.alloca_once_value(builder, var_val)
        # Slow path, only when the slot is still NULL: call `get_ptr_func`
        # to obtain the address, publish it to the global with a release
        # store, and use it for this call as well.
        with builder.if_then(builder.icmp_signed("==", var_val, null), likely=False):
            sig = typingctx.resolve_function_type(get_ptr_func, [], {})
            f = context.get_function(get_ptr_func, sig)
            new_ptr = f(builder, [])
            # `get_ptr_func` returns the address as an integer; cast to i8*.
            new_ptr = builder.inttoptr(new_ptr, ptr_ty)
            builder.store_atomic(new_ptr, var, "release", align)
            builder.store(new_ptr, result_ptr)

        # Wrap the raw address in a first-class function value of `func_type`
        # so callers can invoke it like a normal function.
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
        sfunc.c_addr = builder.load(result_ptr)
        return sfunc._getvalue()

    sig = func_type(get_ptr_func, func_type_ref, unique_func_name_lit)
    return sig, codegen
......@@ -6,7 +6,6 @@ from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
numba_funcify,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
......@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import (
)
@numba_funcify.register(Cholesky)
@register_funcify_default_op_cache_key(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
"""
Overload scipy.linalg.cholesky with a numba function.
......@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return res
return cholesky
cache_key = 1
return cholesky, cache_key
@register_funcify_default_op_cache_key(PivotToPermutations)
......@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs):
return numba_pivot_to_permutation, cache_key
@numba_funcify.register(LU)
@register_funcify_default_op_cache_key(LU)
def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
......@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs):
return res
return lu
cache_key = 1
return lu, cache_key
@numba_funcify.register(LUFactor)
@register_funcify_default_op_cache_key(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c":
......@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return LU, piv
return lu_factor
cache_key = 1
return lu_factor, cache_key
@register_funcify_default_op_cache_key(BlockDiagonal)
......@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return block_diag
@numba_funcify.register(Solve)
@register_funcify_default_op_cache_key(Solve)
def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype
......@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs):
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res
return solve
cache_key = 1
return solve, cache_key
@numba_funcify.register(SolveTriangular)
@register_funcify_default_op_cache_key(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower
unit_diagonal = op.unit_diagonal
......@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
return res
return solve_triangular
cache_key = 1
return solve_triangular, cache_key
@numba_funcify.register(CholeskySolve)
@register_funcify_default_op_cache_key(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower
overwrite_b = op.overwrite_b
......@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite=check_finite,
)
return cho_solve
cache_key = 1
return cho_solve, cache_key
@numba_funcify.register(QR)
@register_funcify_default_op_cache_key(QR)
def numba_funcify_QR(op, node, **kwargs):
mode = op.mode
check_finite = op.check_finite
......@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs):
f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
)
return qr
cache_key = 1
return qr, cache_key
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论