提交 74ab0383 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Manual control of numba caching

上级 5fbf81df
......@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
......
import logging
import os
import shutil
import sys
from pathlib import Path
......@@ -74,7 +75,10 @@ def main():
'You can also call "pytensor-cache purge" to '
"remove everything from that directory."
)
_logger.debug(f"Remaining elements ({len(items)}): {', '.join(items)}")
_logger.debug(f"Remaining elements ({len(items)}): {items}")
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)
elif sys.argv[1] == "list":
pytensor.compile.compiledir.print_compiledir_content()
elif sys.argv[1] == "cleanup":
......@@ -86,6 +90,8 @@ def main():
print("Lock successfully removed!")
elif sys.argv[1] == "purge":
pytensor.compile.compiledir.compiledir_purge()
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)
elif sys.argv[1] == "basecompiledir":
# Simply print the base_compiledir
print(pytensor.config.base_compiledir)
......
from collections.abc import Callable
from hashlib import sha256
from pathlib import Path
from pickle import dump
from tempfile import NamedTemporaryFile
from typing import Any
from weakref import WeakKeyDictionary
from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config
# Directory where PyTensor-generated Numba functions are cached on disk.
# parents=True also creates the base compiledir if it does not exist yet,
# so a fresh (or purged) compiledir does not make this import fail;
# exist_ok=True makes repeated imports a no-op.
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
NUMBA_CACHE_PATH.mkdir(parents=True, exist_ok=True)

# Maps functions compiled from PyTensor-generated source to their cache key.
# A WeakKeyDictionary so entries vanish once the function is garbage collected.
CACHED_SRC_FUNCTIONS: WeakKeyDictionary[Callable, str] = WeakKeyDictionary()
class NumbaPyTensorCacheLocator(_CacheLocator):
    """Locator for Numba functions defined from PyTensor-generated source code.

    It uses an internally-defined hash to disambiguate functions.

    Functions returned by the PyTensor dispatchers are cached in the
    CACHED_SRC_FUNCTIONS weakref dictionary when `compile_numba_function_src`
    is called with a `cache_key`. When numba later attempts to find a cache
    for such a function, this locator gets triggered and directs numba to the
    PyTensor Numba cache directory, using the provided hash as disambiguator.

    It is not necessary that the python functions be cached by the dispatchers.
    As long as the key is the same, numba will be directed to the same cache
    entry, even if the function is fresh. Conversely, if the function changed
    but the key is the same, numba will still use the old cache.
    """

    def __init__(self, py_func, py_file, hash):
        self._py_func = py_func
        self._py_file = py_file
        self._hash = hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """Create a locator instance for functions stored in CACHED_SRC_FUNCTIONS."""
        if not config.numba__cache:
            return None
        cache_key = CACHED_SRC_FUNCTIONS.get(py_func)
        if cache_key is None:
            # Not a PyTensor-registered function; let other locators handle it.
            return None
        return cls(py_func, Path(py_file).parent, cache_key)

    def ensure_cache_path(self):
        """No-op: the cache directory was created when this module was loaded.

        It's too slow to run every time a cache is needed.
        """

    def get_cache_path(self):
        """Return the directory the function is cached in."""
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """Get a timestamp representing the source code's freshness.

        Can return any picklable Python object.
        This can be used to invalidate all caches from previous PyTensor releases.
        """
        return 0

    def get_disambiguator(self):
        """Get a string disambiguator for this locator's function.

        It should allow disambiguating different but similarly-named functions.
        """
        return self._hash
# Register our locator at the front of Numba's locator list so it is consulted
# before numba's built-in locators when resolving a function's cache location.
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
def hash_from_pickle_dump(obj: Any) -> str:
    """Create a sha256 hash from the pickle dump of an object.

    The pickle stream is fed straight into the hasher, so no large
    intermediate ``bytes`` object is ever materialized.
    """
    digest = sha256()

    class _HashSink:
        # File-like object whose only job is to fold written chunks
        # into the running sha256 digest.
        @staticmethod
        def write(chunk):
            digest.update(chunk)

    dump(obj, _HashSink())
    return digest.hexdigest()
def compile_numba_function_src(
src: str,
function_name: str,
global_env: dict[Any, Any] | None = None,
local_env: dict[Any, Any] | None = None,
write_to_disk: bool = False,
cache_key: str | None = None,
) -> Callable:
"""Compile (and optionally cache) a function from source code for use with Numba.
This function compiles the provided source code string into a Python function
with the specified name. If `store_to_disk` is True, the source code is written
to a temporary file before compilation. The compiled function is then executed
in the provided global and local environments.
If a `cache_key` is provided the function is registered in a `CACHED_SRC_FUNCTIONS`
weak reference dictionary, to be used by the `NumbaPyTensorCacheLocator` for caching.
"""
if write_to_disk:
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(src.encode())
else:
filename = "<string>"
if global_env is None:
global_env = {}
if local_env is None:
local_env = {}
mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env)
res = local_env[function_name]
res.__source__ = src
if cache_key is not None:
CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore
......@@ -30,21 +30,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
@numba_basic.numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]
else:
@numba_basic.numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)
return opfromgraph
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
@numba_funcify.register(TypeCastingOp)
......
......@@ -220,14 +220,9 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(op.fgraph, force_scalar=True)
_ = kwargs.pop("storage_map", None)
composite_fn = numba_basic.numba_njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
@numba_funcify.register(Second)
......
......@@ -97,7 +97,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
rewriter(fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
scan_inner_func = numba_funcify(op.fgraph)
outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
......
......@@ -5,15 +5,17 @@ class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify
# Import numba_njit_and_cache lazily (as numba is an optional dependency)
# This is what triggers the registering of the dispatches as well
from pytensor.link.numba.dispatch.basic import numba_funcify_ensure_cache
return numba_funcify(fgraph, **kwargs)
return numba_funcify_ensure_cache(fgraph, **kwargs)
def jit_compile(self, fn):
def jit_compile(self, fn_and_cache):
from pytensor.link.numba.dispatch.basic import numba_njit
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn
fn, cache_key = fn_and_cache
return numba_njit(fn.py_func, final_function=True, cache=cache_key is not None)
def create_thunk_inputs(self, storage_map):
return [storage_map[n] for n in self.fgraph.inputs]
......@@ -8,6 +8,7 @@ import pytest
import scipy
from pytensor.compile import SymbolicInput
from pytensor.tensor.utils import hash_from_ndarray
numba = pytest.importorskip("numba")
......@@ -22,6 +23,7 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import cache_key_for_constant
from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise
......@@ -131,10 +133,14 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
return tuple(ll)
def njit_noop(*args, **kwargs):
def add_py_func_attr(x):
x.py_func = x
return x
if len(args) == 1 and callable(args[0]):
return args[0]
return add_py_func_attr(args[0])
else:
return lambda x: x
return lambda x: add_py_func_attr(x)
mocks = [
mock.patch("numba.njit", njit_noop),
......@@ -396,8 +402,8 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] == {
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] == {
"afn",
"arcp",
"contract",
......@@ -405,19 +411,26 @@ def test_config_options_fastmath():
"reassoc",
}
with config.change_flags(numba__fastmath=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] is False
def test_config_options_cached():
x = pt.dvector()
with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
# Caching is disabled unless the dispatched function returns an explicit cache key
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
# Without caching we don't wrap the function in jitable_func
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion():
......@@ -456,3 +469,180 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000
benchmark(fn, test_x)
class ComplexType:
    """Simple two-attribute container used to exercise pickle-based cache keys."""

    def __init__(self, a, b):
        self.a, self.b = a, b
class TestKeyForConstant:
    """Tests for `cache_key_for_constant`.

    Equal constants must map to equal keys; distinct constants (including the
    same value under a different dtype/type) must map to distinct keys.
    """

    def test_numpy_scalars(self):
        key_float64_0 = cache_key_for_constant(np.float64(0))
        key_float64_0_again = cache_key_for_constant(np.float64(0))
        # Same value, different dtype: must not share a key
        key_float32_0 = cache_key_for_constant(np.float32(0))

        assert key_float64_0 == key_float64_0_again
        assert key_float64_0 != key_float32_0

    def test_None(self):
        key_none_1 = cache_key_for_constant(None)
        key_none_2 = cache_key_for_constant(None)

        assert key_none_1 == key_none_2

    def test_python_scalars(self):
        key_int_0 = cache_key_for_constant(0)
        key_int_0_again = cache_key_for_constant(0)
        # 0 and 0.0 compare equal in Python, but must not share a cache key
        key_float_0 = cache_key_for_constant(0.0)

        assert key_int_0 == key_int_0_again
        assert key_int_0 != key_float_0

    def test_numpy_arrays(self):
        # Just check we are using hash_from_ndarray and trust that it works.
        # If we change our implementation we may need more exhaustive tests here.
        arr1 = np.array([1, 2, 3], dtype=np.float32)
        arr2 = np.array([1, 3, 2], dtype=np.float32)

        key_arr1 = cache_key_for_constant(arr1)
        expected_key_arr1 = hash_from_ndarray(arr1)
        key_arr2 = cache_key_for_constant(arr2)
        expected_key_arr2 = hash_from_ndarray(arr2)

        assert key_arr1 == expected_key_arr1
        assert key_arr2 == expected_key_arr2
        assert key_arr1 != key_arr2

    def test_complex_types(self):
        obj1 = ComplexType(1, 2)
        obj1_again = ComplexType(1, 2)
        obj2 = ComplexType(3, 4)

        key_obj1 = cache_key_for_constant(obj1)
        key_obj1_again = cache_key_for_constant(obj1_again)
        key_obj2 = cache_key_for_constant(obj2)

        assert key_obj1 == key_obj1_again
        assert key_obj1 != key_obj2
def test_funcify_dispatch_interop():
    """Test that the different funcify registration decorators work together as expected."""

    class BaseOp(Op):
        # All test ops below map one float64 scalar to one float64 scalar
        itypes = [pt.dscalar]
        otypes = [pt.dscalar]

    class FuncifiedOp(BaseOp):
        # Registered below via plain `numba_funcify` (no cache key)
        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] + 1

    class FuncifiedAndCachedOp(BaseOp):
        # Registered below via `register_funcify_and_cache_key` (explicit key)
        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] * 2

    class FuncifiedAndDefaultCachedOp(BaseOp):
        # Registered below via `register_funcify_default_op_cache_key`;
        # presumably __props__ is needed so a default key can be derived
        # from the op itself — TODO confirm against the decorator's contract
        __props__ = ()

        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] - 3

    @numba_basic.numba_funcify.register(FuncifiedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x + 1

        return impl

    @numba_basic.register_funcify_and_cache_key(FuncifiedAndCachedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x * 2

        # The second element of the tuple is the explicit cache key
        return impl, "sushi-hash"

    @numba_basic.register_funcify_default_op_cache_key(FuncifiedAndDefaultCachedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x - 3

        return impl

    x = pt.scalar("x", dtype="float64")
    outs = [
        FuncifiedOp()(x),
        FuncifiedAndCachedOp()(x),
        FuncifiedAndDefaultCachedOp()(x),
    ]
    test_x = np.array(5.0)
    compare_numba_and_py(
        [x],
        outs,
        [test_x],
    )

    # Test we can use numba_funcify_ensure_cache
    fn0, cache0 = numba_basic.numba_funcify_ensure_cache(
        outs[0].owner.op, outs[0].owner
    )
    # Plain numba_funcify registration yields no cache key
    assert cache0 is None
    assert numba.njit(lambda x: fn0(x))(test_x) == 6

    fn1, cache1 = numba_basic.numba_funcify_ensure_cache(
        outs[1].owner.op, outs[1].owner
    )
    assert cache1 == "sushi-hash"
    assert numba.njit(lambda x: fn1(x))(test_x) == 10

    fn2, cache2 = numba_basic.numba_funcify_ensure_cache(
        outs[2].owner.op, outs[2].owner
    )
    # Default-derived cache key exists and must be stable across calls
    assert cache2 is not None
    assert numba.njit(lambda x: fn2(x))(test_x) == 2
    fn2_again, cache2_again = numba_basic.numba_funcify_ensure_cache(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 == cache2_again
    assert numba.njit(lambda x: fn2_again(x))(test_x) == 2

    # Test we can use numba_funcify directly
    fn0 = numba_basic.numba_funcify(outs[0].owner.op, outs[0].owner)
    assert numba.njit(lambda x: fn0(x))(test_x) == 6
    fn1 = numba_basic.numba_funcify(outs[1].owner.op, outs[1].owner)
    assert numba.njit(lambda x: fn1(x))(test_x) == 10
    fn2 = numba_basic.numba_funcify(outs[2].owner.op, outs[2].owner)
    assert numba.njit(lambda x: fn2(x))(test_x) == 2

    # Test we can use numba_funcify_and_cache_key directly
    fn0, cache0 = numba_basic.numba_funcify_and_cache_key(
        outs[0].owner.op, outs[0].owner
    )
    assert cache0 is None
    assert numba.njit(lambda x: fn0(x))(test_x) == 6
    fn1, cache1 = numba_basic.numba_funcify_and_cache_key(
        outs[1].owner.op, outs[1].owner
    )
    assert cache1 == "sushi-hash"
    assert numba.njit(lambda x: fn1(x))(test_x) == 10
    fn2, cache2 = numba_basic.numba_funcify_and_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 is not None
    assert numba.njit(lambda x: fn2(x))(test_x) == 2
    fn2_again, cache2_again = numba_basic.numba_funcify_and_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 == cache2_again
    assert numba.njit(lambda x: fn2_again(x))(test_x) == 2

    # Test numba_funcify_default_op_cache_key works as expected:
    # ops not registered with the default-key decorator must raise
    with pytest.raises(NotImplementedError):
        numba_basic.numba_funcify_default_op_cache_key(outs[0].owner.op, outs[0].owner)
    with pytest.raises(NotImplementedError):
        numba_basic.numba_funcify_default_op_cache_key(outs[1].owner.op, outs[1].owner)
    fn2_def_cached = numba_basic.numba_funcify_default_op_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论