提交 74ab0383 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Manual control of numba caching

上级 5fbf81df
......@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
......
import logging
import os
import shutil
import sys
from pathlib import Path
......@@ -74,7 +75,10 @@ def main():
'You can also call "pytensor-cache purge" to '
"remove everything from that directory."
)
_logger.debug(f"Remaining elements ({len(items)}): {', '.join(items)}")
_logger.debug(f"Remaining elements ({len(items)}): {items}")
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)
elif sys.argv[1] == "list":
pytensor.compile.compiledir.print_compiledir_content()
elif sys.argv[1] == "cleanup":
......@@ -86,6 +90,8 @@ def main():
print("Lock successfully removed!")
elif sys.argv[1] == "purge":
pytensor.compile.compiledir.compiledir_purge()
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)
elif sys.argv[1] == "basecompiledir":
# Simply print the base_compiledir
print(pytensor.config.base_compiledir)
......
from collections.abc import Callable
from hashlib import sha256
from pathlib import Path
from pickle import dump
from tempfile import NamedTemporaryFile
from typing import Any
from weakref import WeakKeyDictionary
from numba.core.caching import CacheImpl, _CacheLocator
from pytensor.configdefaults import config
# Directory where Numba caches of PyTensor-generated functions are stored on disk.
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
# parents=True guards against a missing base_compiledir (e.g. right after a
# manual purge of the compile directory); exist_ok keeps repeat imports cheap.
NUMBA_CACHE_PATH.mkdir(parents=True, exist_ok=True)

# Maps PyTensor-generated python functions to their cache-key hash.
# Weak keys so entries vanish once the function object is garbage collected.
CACHED_SRC_FUNCTIONS: WeakKeyDictionary[Callable, str] = WeakKeyDictionary()
class NumbaPyTensorCacheLocator(_CacheLocator):
    """Locator that points Numba at PyTensor's own on-disk cache directory.

    Functions produced by the PyTensor dispatchers are recorded in the
    CACHED_SRC_FUNCTIONS weakref dictionary when `compile_numba_function_src`
    is called with a `cache_key`. When Numba later looks for a cache for such
    a function, this locator fires and redirects it to the PyTensor Numba
    cache directory, using the recorded hash as disambiguator.

    The python function itself need not be cached by the dispatchers: as long
    as the key matches, Numba is sent to the same cache entry even for a
    freshly created function object. Conversely, if the function changed but
    the key did not, Numba will still use the old cache entry.
    """

    def __init__(self, py_func, py_file, hash):
        self._py_func = py_func
        self._py_file = py_file
        self._hash = hash

    def ensure_cache_path(self):
        """No-op: the directory is created once when the module is loaded.

        Re-creating it on every cache access would be too slow.
        """
        pass

    def get_cache_path(self):
        """Return the directory the function is cached in."""
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """Return a constant freshness stamp (any picklable object is valid).

        Bumping this value would invalidate all caches produced by previous
        PyTensor releases.
        """
        return 0

    def get_disambiguator(self):
        """Return the hash that tells similarly-named functions apart."""
        return self._hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """Build a locator, but only for functions known to CACHED_SRC_FUNCTIONS."""
        if not config.numba__cache:
            return None
        # `in` (rather than .get) is deliberate: WeakKeyDictionary.__contains__
        # tolerates non-weakref-able keys, which Numba may pass here.
        if py_func not in CACHED_SRC_FUNCTIONS:
            return None
        return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])


# Give our locator priority over Numba's built-in ones
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
def hash_from_pickle_dump(obj: Any) -> str:
    """Return the sha256 hex digest of the pickle serialization of `obj`.

    The pickle stream is fed into the hasher incrementally, so a full
    serialized copy of the object is never materialized in memory.
    """
    digest = sha256()

    class _HasherSink:
        # Minimal file-like object: pickle only ever calls `write`.
        def write(self, chunk):
            digest.update(chunk)

    dump(obj, _HasherSink())
    return digest.hexdigest()
def compile_numba_function_src(
src: str,
function_name: str,
global_env: dict[Any, Any] | None = None,
local_env: dict[Any, Any] | None = None,
write_to_disk: bool = False,
cache_key: str | None = None,
) -> Callable:
"""Compile (and optionally cache) a function from source code for use with Numba.
This function compiles the provided source code string into a Python function
with the specified name. If `store_to_disk` is True, the source code is written
to a temporary file before compilation. The compiled function is then executed
in the provided global and local environments.
If a `cache_key` is provided the function is registered in a `CACHED_SRC_FUNCTIONS`
weak reference dictionary, to be used by the `NumbaPyTensorCacheLocator` for caching.
"""
if write_to_disk:
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(src.encode())
else:
filename = "<string>"
if global_env is None:
global_env = {}
if local_env is None:
local_env = {}
mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env)
res = local_env[function_name]
res.__source__ = src
if cache_key is not None:
CACHED_SRC_FUNCTIONS[res] = cache_key
return res # type: ignore
import warnings
from functools import singledispatch
from collections.abc import Callable
from functools import singledispatch, wraps
from hashlib import sha256
import numba
import numpy as np
from numba.core.errors import NumbaWarning
from numba import njit as _njit
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
fgraph_to_python,
......@@ -17,12 +20,21 @@ from pytensor.link.utils import (
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.utils import hash_from_ndarray
def numba_njit(*args, fastmath=None, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
def numba_njit(
*args, fastmath=None, final_function: bool = False, **kwargs
) -> Callable:
"""A thin wrapper around `numba.njit`.
If `final_function` is `False` (default), the flags `no_cpython_wrapper` and `no_cfunc_wrapper` are set to `True`.
This speedups compilation for functions that need not be called directly from Python.
This function also sets opinionated defaults for the `fastmath` argument based on the
`pytensor.config.numba__fastmath` configuration variable.
"""
if fastmath is None:
if config.numba__fastmath:
# Opinionated default on fastmath flags
......@@ -37,23 +49,15 @@ def numba_njit(*args, fastmath=None, **kwargs):
else:
fastmath = False
# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings.filterwarnings(
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
)
if not final_function:
# These slow down compilation and are not necessary for functions not called directly from Python
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
return numba.njit(*args, fastmath=fastmath, **kwargs)
return _njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) # type: ignore
else:
return _njit(*args, fastmath=fastmath, **kwargs) # type: ignore
def get_numba_type(
......@@ -261,17 +265,275 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return generate_fallback_impl(op, node, storage_map, **kwargs)
@numba_funcify.register(FunctionGraph)
@singledispatch
def numba_funcify_default_op_cache_key(
op, node=None, **kwargs
) -> Callable | tuple[Callable, int]:
"""Funcify an Op and allow a default cache key to be generated for it.
Wrapped function can return an integer in addition to the generated numba function.
See docstrings of `register_funcify_default_op_cache_key` for details.
"""
raise NotImplementedError()
def register_funcify_default_op_cache_key(op_type):
    """Funcify an Op and allow a default cache key to be generated for it.

    A helper that registers the wrapped function on both
    `numba_funcify_default_op_cache_key` and the legacy `numba_funcify`.

    The cache key itself is ultimately produced by the base case of
    `numba_funcify_and_cache_key` when no more specialized dispatch exists
    for the Op. The wrapped function may return an integer alongside the
    numba function; it is folded into the default key and can signal changes
    across versions.

    The default key is built from the string representation of `type(op)`
    and the bytes of the props serialized by pickle. It ignores input types
    and any other graph context. Numba itself adds input array dtypes, rank
    and layout to its own cache key — but not static shapes, broadcastable
    patterns or constant values.

    If the funcify implementation exploits information beyond the Op class,
    its `_props`, or what numba already tracks, do not use this helper;
    register a custom key via `register_funcify_and_cache_key` instead.
    """

    def decorator(dispatch_func):
        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)

        @wraps(dispatch_func)
        def strip_key_salt(*args, **kwargs):
            # The legacy `numba_funcify` callers only want the function itself,
            # so drop the optional integer key salt.
            result = dispatch_func(*args, **kwargs)
            return result[0] if isinstance(result, tuple) else result

        numba_funcify.register(op_type)(strip_key_salt)

        # Hand back the original function untouched
        return dispatch_func

    return decorator
@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
    """Funcify an Op and return a unique cache key that can be used by numba caching.

    A cache key of `None` can be returned to indicate that a function can't be cached.
    See docstrings of `register_funcify_default_op_cache_key` for details.
    """
    # The base case of this dispatch (if nothing specialized was registered), is to
    # 1. Attempt to use `numba_funcify_default_op_cache_key`,
    #    which indicates a simple cache key based on the Op and its _props can be
    #    safely used to uniquely identify the returned numba function
    # 2. If that fails, attempt to use the legacy `numba_funcify`.
    #    In this case a `None` is returned as the cache_key to indicate the function
    #    cannot be safely cached.
    try:
        func_and_int = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
    except NotImplementedError:
        # Fallback
        return numba_funcify(op, node=node, **kwargs), None

    if isinstance(func_and_int, tuple):
        func, integer = func_and_int
        if isinstance(integer, int):
            integer_str = str(integer)
        else:
            # Input validation
            if integer is None:  # type: ignore[unreachable]
                raise TypeError(
                    "The function wrapped by `numba_funcify_default_op_cache_key` returned None as its second output, "
                    "but only integers are allowed.\nIf the function cannot be cached, the wrapper shouldn't be used. "
                    "You can use `numba_funcify_and_cache_key` to optionally return None",
                )
            else:
                raise TypeError(
                    f"The function wrapped by numba_funcify_default_op_cache_key returned {integer} of type {type(integer)} "
                    "as its second output, but only integers are allowed."
                )
    else:
        # No key salt returned; use the literal string "None" in the key material
        func, integer_str = func_and_int, "None"

    try:
        props_dict = op._props_dict()
    except AttributeError:
        raise ValueError(
            "The function wrapped by `numba_funcify_default_op_cache_key` can only be used with Ops with `_props`, "
            f"but {op} of type {type(op)} has no _props defined (not even empty)."
        )

    # NOTE: renamed from `hash`, which shadowed the builtin of the same name
    if not props_dict:
        # Simple op, just use the type string as key
        key_hash = sha256(f"({type(op)}, {integer_str})".encode()).hexdigest()
    else:
        # Simple props, can use string representation of props as key
        simple_types = (str, bool, int, type(None), float)
        container_types = (tuple, frozenset)
        if all(
            isinstance(v, simple_types)
            or (
                isinstance(v, container_types)
                and all(isinstance(i, simple_types) for i in v)
            )
            for v in props_dict.values()
        ):
            key_hash = sha256(
                f"({type(op)}, {tuple(props_dict.items())}, {integer_str})".encode()
            ).hexdigest()
        else:
            # Complex props, use pickle to serialize them
            key_hash = hash_from_pickle_dump(
                (str(type(op)), tuple(props_dict.items()), integer_str),
            )
    return func, key_hash
def register_funcify_and_cache_key(op_type):
    """Funcify an Op and return a unique cache key that can be used by numba caching.

    A helper that registers the wrapped function on both
    `numba_funcify_and_cache_key` and the legacy `numba_funcify`.

    Numba folds input array dtypes, rank and layout into its own cache key,
    but not static shapes, broadcastable patterns or constant values. The
    cache_key must therefore uniquely identify the generated function among
    all possible PyTensor Ops and graphs, modulo what numba already tracks.
    Returning `None` as the cache key marks the function as non-cacheable.

    For simple cases consider the helper
    `register_funcify_default_op_cache_key` — but read the limitations in
    its docstring first!
    """

    def decorator(dispatch_func):
        numba_funcify_and_cache_key.register(op_type)(dispatch_func)

        @wraps(dispatch_func)
        def drop_cache_key(*args, **kwargs):
            # The legacy dispatcher only needs the compiled function
            func, _ = dispatch_func(*args, **kwargs)
            return func

        numba_funcify.register(op_type)(drop_cache_key)
        return dispatch_func

    return decorator
def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | None]:
    """Obtain a numba function for an Op and ensure it can be cached by numba.

    When `config.numba__cache` is enabled and `numba_funcify_and_cache_key`
    yields a non-None key, the function is re-wrapped in a python-compiled
    trampoline that hoists closures to the global scope. Combined with the
    NumbaPyTensorCacheLocator, this makes numba hit our cache. Without it,
    numba frequently rejects caches, which was always the case for:

    1. Ops using the custom vectorize intrinsic: Elemwise, Blockwise, RandomVariables
    2. String generated functions: Alloc, Scan, OpFromGraph, and FunctionGraph itself
    """
    if config.numba__cache:
        jitable_func, cache_key = numba_funcify_and_cache_key(op, *args, **kwargs)
    else:
        jitable_func, cache_key = numba_funcify(op, *args, **kwargs), None

    if cache_key is None:
        # Nothing to cache; optionally tell the user why compilation stays slow
        if config.numba__cache and config.compiler_verbose:
            print(f"{op} of type {type(op)} will not be cached by PyTensor.\n")  # noqa: T201
        return jitable_func, None

    op_name = jitable_func.__name__
    # The trampoline source closes over nothing: `jitable_func` is injected as
    # a global, which is what lets numba cache the compiled result.
    cached_func = compile_numba_function_src(
        src=f"def {op_name}(*args): return jitable_func(*args)",
        function_name=op_name,
        global_env=globals() | {"jitable_func": jitable_func},
        cache_key=cache_key,
    )
    return numba_njit(cached_func, cache=True), cache_key
def cache_key_for_constant(data):
"""Create a cache key for a constant value."""
if isinstance(data, np.number):
return sha256(data.dtype.str.encode() + data.tobytes()).hexdigest()
elif isinstance(data, np.ndarray):
return hash_from_ndarray(data)
elif data is None:
return "None"
elif isinstance(data, int | float | bool):
# These should all really be np.number, but we keep this branch just in case
return str(data)
else:
# Fallback for arbitrary types
return hash_from_pickle_dump(data)
@register_funcify_and_cache_key(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
fgraph: FunctionGraph,
node=None,
fgraph_name="numba_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
# Collect cache keys of every Op/Constant in the FunctionGraph
# so we can create a global cache key for the whole FunctionGraph
cache_keys = []
toposort = fgraph.toposort()
clients = fgraph.clients
toposort_indices = {node: i for i, node in enumerate(toposort)}
# Add dummy output clients which are not included of the toposort
toposort_indices |= {
clients[out][0][0]: i
for i, out in enumerate(fgraph.outputs, start=len(toposort))
}
def op_conversion_and_key_collection(*args, **kwargs):
# Convert an Op to a funcified function and store the cache_key
# We also Cache each Op so Numba can do less work next time it sees it
func, key = numba_funcify_ensure_cache(*args, **kwargs)
cache_keys.append(key)
return func
def type_conversion_and_key_collection(value, variable, **kwargs):
# Convert a constant type to a numba compatible one and compute a cache key for it
# We need to know where in the graph the constants are used
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
# FIXME: It doesn't make sense to call type_conversion on non-constants,
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
if isinstance(variable, Constant):
client_indices = tuple(
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
)
cache_keys.append((client_indices, cache_key_for_constant(value)))
return numba_typify(value, variable=variable, **kwargs)
py_func = fgraph_to_python(
fgraph,
numba_funcify,
type_conversion_fn=numba_typify,
op_conversion_fn=op_conversion_and_key_collection,
type_conversion_fn=type_conversion_and_key_collection,
fgraph_name=fgraph_name,
**kwargs,
)
if any(key is None for key in cache_keys):
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
fgraph_key = None
else:
# Compose individual cache_keys into a global key for the FunctionGraph
fgraph_key = sha256(
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
).hexdigest()
return numba_njit(py_func), fgraph_key
......@@ -30,21 +30,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
@numba_basic.numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]
else:
@numba_basic.numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)
return opfromgraph
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
@numba_funcify.register(TypeCastingOp)
......
......@@ -220,14 +220,9 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(op.fgraph, force_scalar=True)
_ = kwargs.pop("storage_map", None)
composite_fn = numba_basic.numba_njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
@numba_funcify.register(Second)
......
......@@ -97,7 +97,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
rewriter(fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
scan_inner_func = numba_funcify(op.fgraph)
outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
......
......@@ -5,15 +5,17 @@ class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify
# Import numba_njit_and_cache lazily (as numba is an optional dependency)
# This is what triggers the registering of the dispatches as well
from pytensor.link.numba.dispatch.basic import numba_funcify_ensure_cache
return numba_funcify(fgraph, **kwargs)
return numba_funcify_ensure_cache(fgraph, **kwargs)
def jit_compile(self, fn):
def jit_compile(self, fn_and_cache):
from pytensor.link.numba.dispatch.basic import numba_njit
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn
fn, cache_key = fn_and_cache
return numba_njit(fn.py_func, final_function=True, cache=cache_key is not None)
def create_thunk_inputs(self, storage_map):
return [storage_map[n] for n in self.fgraph.inputs]
......@@ -8,6 +8,7 @@ import pytest
import scipy
from pytensor.compile import SymbolicInput
from pytensor.tensor.utils import hash_from_ndarray
numba = pytest.importorskip("numba")
......@@ -22,6 +23,7 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import cache_key_for_constant
from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise
......@@ -131,10 +133,14 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
return tuple(ll)
def njit_noop(*args, **kwargs):
def add_py_func_attr(x):
x.py_func = x
return x
if len(args) == 1 and callable(args[0]):
return args[0]
return add_py_func_attr(args[0])
else:
return lambda x: x
return lambda x: add_py_func_attr(x)
mocks = [
mock.patch("numba.njit", njit_noop),
......@@ -396,8 +402,8 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] == {
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] == {
"afn",
"arcp",
"contract",
......@@ -405,19 +411,26 @@ def test_config_options_fastmath():
"reassoc",
}
with config.change_flags(numba__fastmath=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] is False
def test_config_options_cached():
x = pt.dvector()
with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
# Caching is disabled unless the dispatched function returns an explicit cache key
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
# Without caching we don't wrap the function in jitable_func
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion():
......@@ -456,3 +469,180 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000
benchmark(fn, test_x)
class ComplexType:
    # Simple two-attribute container; presumably used to exercise the
    # pickle-based fallback of cache_key_for_constant — see TestKeyForConstant.
    def __init__(self, a, b):
        self.a, self.b = a, b
class TestKeyForConstant:
    """Unit tests for `cache_key_for_constant`."""

    def test_numpy_scalars(self):
        key_float64_0 = cache_key_for_constant(np.float64(0))
        key_float64_0_again = cache_key_for_constant(np.float64(0))
        # Fixed misnamed variable: this key comes from a float32, not an int64
        key_float32_0 = cache_key_for_constant(np.float32(0))
        assert key_float64_0 == key_float64_0_again
        # Same value, different dtype must not collide
        assert key_float64_0 != key_float32_0

    def test_None(self):
        key_none_1 = cache_key_for_constant(None)
        key_none_2 = cache_key_for_constant(None)
        assert key_none_1 == key_none_2

    def test_python_scalars(self):
        key_int_0 = cache_key_for_constant(0)
        key_int_0_again = cache_key_for_constant(0)
        key_float_0 = cache_key_for_constant(0.0)
        assert key_int_0 == key_int_0_again
        assert key_int_0 != key_float_0

    def test_numpy_arrays(self):
        # Just check we are using hash_from_ndarray and trust that it is working.
        # If we change our implementation we may need more exhaustive tests here.
        arr1 = np.array([1, 2, 3], dtype=np.float32)
        arr2 = np.array([1, 3, 2], dtype=np.float32)
        key_arr1 = cache_key_for_constant(arr1)
        expected_key_arr1 = hash_from_ndarray(arr1)
        key_arr2 = cache_key_for_constant(arr2)
        expected_key_arr2 = hash_from_ndarray(arr2)
        assert key_arr1 == expected_key_arr1
        assert key_arr2 == expected_key_arr2
        assert key_arr1 != key_arr2

    def test_complex_types(self):
        obj1 = ComplexType(1, 2)
        ob1_again = ComplexType(1, 2)
        obj2 = ComplexType(3, 4)
        key_obj1 = cache_key_for_constant(obj1)
        key_obj1_again = cache_key_for_constant(ob1_again)
        key_obj2 = cache_key_for_constant(obj2)
        assert key_obj1 == key_obj1_again
        assert key_obj1 != key_obj2
def test_funcify_dispatch_interop():
    """Test that the different funcify registration decorators work together as expected."""

    # Minimal scalar Op base: one float64 input, one float64 output.
    class BaseOp(Op):
        itypes = [pt.dscalar]
        otypes = [pt.dscalar]

    # Registered below via the legacy `numba_funcify` only (no cache key).
    class FuncifiedOp(BaseOp):
        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] + 1

    # Registered below via `register_funcify_and_cache_key` (explicit key).
    class FuncifiedAndCachedOp(BaseOp):
        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] * 2

    # Registered below via `register_funcify_default_op_cache_key`; declares
    # (empty) __props__ so a default cache key can be derived from them.
    class FuncifiedAndDefaultCachedOp(BaseOp):
        __props__ = ()

        def perform(self, node, inputs, outputs):
            outputs[0][0] = inputs[0] - 3

    @numba_basic.numba_funcify.register(FuncifiedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x + 1

        return impl

    @numba_basic.register_funcify_and_cache_key(FuncifiedAndCachedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x * 2

        return impl, "sushi-hash"

    @numba_basic.register_funcify_default_op_cache_key(FuncifiedAndDefaultCachedOp)
    def _(op, node, **kwargs):
        @numba_basic.numba_njit
        def impl(x):
            return x - 3

        return impl

    x = pt.scalar("x", dtype="float64")
    outs = [
        FuncifiedOp()(x),
        FuncifiedAndCachedOp()(x),
        FuncifiedAndDefaultCachedOp()(x),
    ]
    test_x = np.array(5.0)
    compare_numba_and_py(
        [x],
        outs,
        [test_x],
    )

    # Test we can use numba_funcify_ensure_cache
    # Legacy-only registration: no key, function still works (5 + 1 == 6)
    fn0, cache0 = numba_basic.numba_funcify_ensure_cache(
        outs[0].owner.op, outs[0].owner
    )
    assert cache0 is None
    assert numba.njit(lambda x: fn0(x))(test_x) == 6
    fn1, cache1 = numba_basic.numba_funcify_ensure_cache(
        outs[1].owner.op, outs[1].owner
    )
    assert cache1 == "sushi-hash"
    assert numba.njit(lambda x: fn1(x))(test_x) == 10
    fn2, cache2 = numba_basic.numba_funcify_ensure_cache(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 is not None
    assert numba.njit(lambda x: fn2(x))(test_x) == 2
    # The default key must be deterministic across repeated calls
    fn2_again, cache2_again = numba_basic.numba_funcify_ensure_cache(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 == cache2_again
    assert numba.njit(lambda x: fn2_again(x))(test_x) == 2

    # Test we can use numba_funcify directly
    fn0 = numba_basic.numba_funcify(outs[0].owner.op, outs[0].owner)
    assert numba.njit(lambda x: fn0(x))(test_x) == 6
    fn1 = numba_basic.numba_funcify(outs[1].owner.op, outs[1].owner)
    assert numba.njit(lambda x: fn1(x))(test_x) == 10
    fn2 = numba_basic.numba_funcify(outs[2].owner.op, outs[2].owner)
    assert numba.njit(lambda x: fn2(x))(test_x) == 2

    # Test we can use numba_funcify_and_cache_key directly
    fn0, cache0 = numba_basic.numba_funcify_and_cache_key(
        outs[0].owner.op, outs[0].owner
    )
    assert cache0 is None
    assert numba.njit(lambda x: fn0(x))(test_x) == 6
    fn1, cache1 = numba_basic.numba_funcify_and_cache_key(
        outs[1].owner.op, outs[1].owner
    )
    assert cache1 == "sushi-hash"
    assert numba.njit(lambda x: fn1(x))(test_x) == 10
    fn2, cache2 = numba_basic.numba_funcify_and_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 is not None
    assert numba.njit(lambda x: fn2(x))(test_x) == 2
    fn2_again, cache2_again = numba_basic.numba_funcify_and_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert cache2 == cache2_again
    assert numba.njit(lambda x: fn2_again(x))(test_x) == 2

    # Test numba_funcify_default_op_cache_key works as expected
    # Ops without a default-cache-key registration must raise
    with pytest.raises(NotImplementedError):
        numba_basic.numba_funcify_default_op_cache_key(outs[0].owner.op, outs[0].owner)
    with pytest.raises(NotImplementedError):
        numba_basic.numba_funcify_default_op_cache_key(outs[1].owner.op, outs[1].owner)
    fn2_def_cached = numba_basic.numba_funcify_default_op_cache_key(
        outs[2].owner.op, outs[2].owner
    )
    assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论