Unverified 提交 5137ed3e authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Allow caching of cython functions in numba backend (#1807)

* Add machinery to enable numba caching of function pointers * Numba cache for Cholesky * Numba cache for SolveTriangular * Numba cache for CholeskySolve * Numba cache for solve helpers * Numba cache for GECON * Numba cache for lu_factor * Numba cache for Solve when assume_a="gen" * Numba cache for Solve when assume_a="sym" * Numba cache for Solve when assume_a="pos" * Numba cache for Solve when assume_a='tri' * Numba cache for QR * Clean up obsolete code * Feedback * More feedback * Rename `cache_key_lit` -> `unique_func_name_lit`
上级 0b439c0f
...@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile ...@@ -5,6 +5,9 @@ from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
import numba
from llvmlite import ir
from numba.core import cgutils
from numba.core.caching import CacheImpl, _CacheLocator from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -127,3 +130,49 @@ def compile_numba_function_src( ...@@ -127,3 +130,49 @@ def compile_numba_function_src(
CACHED_SRC_FUNCTIONS[res] = cache_key CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore return res # type: ignore
@numba.extending.intrinsic(prefer_literal=True)
def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref, unique_func_name_lit):
"""
Enable caching of function pointers returned by `get_ptr_func`.
When one of our Numba-dispatched functions depends on a pointer to a compiled function function (e.g. when we call
cython_lapack routines), numba will refuse to cache the function, because the pointer may change between runs.
This intrinsic allows us to cache the pointer ourselves, by storing it in a global variable keyed by a literal
`unique_func_name_lit`. The first time the intrinsic is called, it will call `get_ptr_func` to get the pointer, store it
in the global variable, and return it. Subsequent calls will load the pointer from the global variable.
"""
func_type = func_type_ref.instance_type
cache_key = unique_func_name_lit.literal_value
def codegen(context, builder, signature, args):
ptr_ty = ir.PointerType(ir.IntType(8))
null = ptr_ty(None)
align = 64
mod = builder.module
var = cgutils.add_global_variable(mod, ptr_ty, f"_ptr_cache_{cache_key}")
var.align = align
var.linkage = "private"
var.initializer = null
var_val = builder.load_atomic(var, "acquire", align)
result_ptr = cgutils.alloca_once_value(builder, var_val)
with builder.if_then(builder.icmp_signed("==", var_val, null), likely=False):
sig = typingctx.resolve_function_type(get_ptr_func, [], {})
f = context.get_function(get_ptr_func, sig)
new_ptr = f(builder, [])
new_ptr = builder.inttoptr(new_ptr, ptr_ty)
builder.store_atomic(new_ptr, var, "release", align)
builder.store(new_ptr, result_ptr)
sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
sfunc.c_addr = builder.load(result_ptr)
return sfunc._getvalue()
sig = func_type(get_ptr_func, func_type_ref, unique_func_name_lit)
return sig, codegen
...@@ -6,7 +6,6 @@ from pytensor import config ...@@ -6,7 +6,6 @@ from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl, generate_fallback_impl,
numba_funcify,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
...@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import ( ...@@ -44,7 +43,7 @@ from pytensor.tensor.slinalg import (
) )
@numba_funcify.register(Cholesky) @register_funcify_default_op_cache_key(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs): def numba_funcify_Cholesky(op, node, **kwargs):
""" """
Overload scipy.linalg.cholesky with a numba function. Overload scipy.linalg.cholesky with a numba function.
...@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -95,7 +94,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return res return res
return cholesky cache_key = 1
return cholesky, cache_key
@register_funcify_default_op_cache_key(PivotToPermutations) @register_funcify_default_op_cache_key(PivotToPermutations)
...@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -115,7 +115,7 @@ def pivot_to_permutation(op, node, **kwargs):
return numba_pivot_to_permutation, cache_key return numba_pivot_to_permutation, cache_key
@numba_funcify.register(LU) @register_funcify_default_op_cache_key(LU)
def numba_funcify_LU(op, node, **kwargs): def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
...@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -179,10 +179,11 @@ def numba_funcify_LU(op, node, **kwargs):
return res return res
return lu cache_key = 1
return lu, cache_key
@numba_funcify.register(LUFactor) @register_funcify_default_op_cache_key(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs): def numba_funcify_LUFactor(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
...@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -215,7 +216,8 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return LU, piv return LU, piv
return lu_factor cache_key = 1
return lu_factor, cache_key
@register_funcify_default_op_cache_key(BlockDiagonal) @register_funcify_default_op_cache_key(BlockDiagonal)
...@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -240,7 +242,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return block_diag return block_diag
@numba_funcify.register(Solve) @register_funcify_default_op_cache_key(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs) A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
...@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -305,10 +307,11 @@ def numba_funcify_Solve(op, node, **kwargs):
res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
return res return res
return solve cache_key = 1
return solve, cache_key
@numba_funcify.register(SolveTriangular) @register_funcify_default_op_cache_key(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs): def numba_funcify_SolveTriangular(op, node, **kwargs):
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
...@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -358,10 +361,11 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
return res return res
return solve_triangular cache_key = 1
return solve_triangular, cache_key
@numba_funcify.register(CholeskySolve) @register_funcify_default_op_cache_key(CholeskySolve)
def numba_funcify_CholeskySolve(op, node, **kwargs): def numba_funcify_CholeskySolve(op, node, **kwargs):
lower = op.lower lower = op.lower
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
...@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -407,10 +411,11 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite=check_finite, check_finite=check_finite,
) )
return cho_solve cache_key = 1
return cho_solve, cache_key
@numba_funcify.register(QR) @register_funcify_default_op_cache_key(QR)
def numba_funcify_QR(op, node, **kwargs): def numba_funcify_QR(op, node, **kwargs):
mode = op.mode mode = op.mode
check_finite = op.check_finite check_finite = op.check_finite
...@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -500,4 +505,5 @@ def numba_funcify_QR(op, node, **kwargs):
f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
) )
return qr cache_key = 1
return qr, cache_key
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论