提交 6ac5ab28 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cache keys for numba Op dispatches

上级 74ab0383
import sys
from hashlib import sha256
from typing import cast
from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.basic import (
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import compile_function_src
from pytensor.tensor import TensorVariable, get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
@numba_funcify.register(BlockwiseWithCoreShape)
@register_funcify_and_cache_key(BlockwiseWithCoreShape)
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
[blockwise_node] = op.fgraph.apply_nodes
blockwise_op: Blockwise = blockwise_node.op
......@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
cast(tuple[TensorVariable], node.inputs[:nin]),
propagate_unbatched_core_inputs=True,
)
core_op_fn = numba_funcify(
core_op_fn, core_op_key = numba_funcify_and_cache_key(
core_op,
node=core_node,
parent_node=node,
......@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src += ")"
to_tuple = numba_basic.numba_njit(
compile_function_src(
compile_numba_function_src(
src,
"to_tuple",
global_env={"to_fixed_tuple": to_fixed_tuple},
),
# cache=True leads to a numba.cloudpickle dump failure in Python 3.10
# May be fine in Python 3.11, but I didn't test. It was fine in 3.12
cache=sys.version_info >= (3, 12),
)
def blockwise_wrapper(*inputs_and_core_shapes):
inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
tuple_core_shapes = to_tuple(core_shapes)
return _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(), # constant_inputs
inputs,
tuple_core_shapes,
None, # size
)
)
def blockwise(*inputs_and_core_shapes):
raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
raise NotImplementedError(
"Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
)
@overload(blockwise, jit_options=_jit_options)
def ov_blockwise(*inputs_and_core_shapes):
return blockwise_wrapper
def impl(*inputs_and_core_shapes):
inputs, core_shapes = (
inputs_and_core_shapes[:nin],
inputs_and_core_shapes[nin:],
)
tuple_core_shapes = to_tuple(core_shapes)
return _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(), # constant_inputs
inputs,
tuple_core_shapes,
None, # size
)
return impl
return blockwise
if core_op_key is None:
# If the core op cannot be cached, the Blockwise wrapper cannot be cached either
blockwise_key = None
else:
blockwise_key = "_".join(
map(
str,
(
type(op),
type(blockwise_op),
tuple(blockwise_op.destroy_map.items()),
blockwise_op.signature,
input_bc_patterns,
core_op_key,
),
)
)
blockwise_key = sha256(blockwise_key.encode()).hexdigest()
return blockwise, blockwise_key
from hashlib import sha256
import numpy as np
from pytensor.compile.builders import OpFromGraph
......@@ -8,14 +10,15 @@ from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
numba_funcify,
numba_njit,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import TensorType
@numba_funcify.register(OpFromGraph)
@register_funcify_and_cache_key(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
......@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
)
if fgraph_cache_key is None:
# Can't cache the inner graph
ofg_cache_key = None
else:
ofg_cache_key = sha256(
str(
(
type(op),
fgraph_cache_key,
)
).encode()
).hexdigest()
return fgraph_fn, ofg_cache_key
@numba_funcify.register(TypeCastingOp)
@register_funcify_default_op_cache_key(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
@numba_basic.numba_njit
def identity(x):
......@@ -42,7 +62,7 @@ def numba_funcify_type_casting(op, **kwargs):
return identity
@numba_funcify.register(DeepCopyOp)
@register_funcify_default_op_cache_key(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
......@@ -59,7 +79,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopy
@numba_funcify.register(IfElse)
@register_funcify_default_op_cache_key(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
......@@ -88,7 +108,7 @@ def numba_funcify_IfElse(op, **kwargs):
return ifelse
@numba_funcify.register(CheckAndRaise)
@register_funcify_and_cache_key(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
error = op.exc_type
msg = op.msg
......@@ -100,4 +120,5 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs):
raise error(msg)
return x
return check_and_raise
cache_key = sha256(str((type(op), error, msg)).encode()).hexdigest()
return check_and_raise, cache_key
import warnings
from hashlib import sha256
from typing import cast
import numba
......@@ -9,7 +10,8 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
get_numba_type,
numba_funcify,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import (
......@@ -25,16 +27,16 @@ from pytensor.tensor.extra_ops import (
)
@numba_funcify.register(Bartlett)
@register_funcify_default_op_cache_key(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def bartlett(x):
return np.bartlett(x.item())
return bartlett
@numba_funcify.register(CumOp)
@register_funcify_default_op_cache_key(CumOp)
def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
axis = op.axis
mode = op.mode
......@@ -94,7 +96,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
return cumop
@numba_funcify.register(FillDiagonal)
@register_funcify_default_op_cache_key(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
@numba_basic.numba_njit
def filldiagonal(a, val):
......@@ -104,7 +106,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
return filldiagonal
@numba_funcify.register(FillDiagonalOffset)
@register_funcify_default_op_cache_key(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba_basic.numba_njit
def filldiagonaloffset(a, val, offset):
......@@ -129,7 +131,7 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
return filldiagonaloffset
@numba_funcify.register(RavelMultiIndex)
@register_funcify_default_op_cache_key(RavelMultiIndex)
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
mode = op.mode
order = op.order
......@@ -194,7 +196,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
return ravelmultiindex
@numba_funcify.register(Repeat)
@register_funcify_default_op_cache_key(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
axis = op.axis
a, _ = node.inputs
......@@ -202,7 +204,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if axis == 0 and a.type.ndim == 1:
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def repeatop(x, repeats):
return np.repeat(x, repeats)
......@@ -212,7 +214,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
return generate_fallback_impl(op, node)
@numba_funcify.register(Unique)
@register_funcify_default_op_cache_key(Unique)
def numba_funcify_Unique(op, node, **kwargs):
axis = op.axis
......@@ -230,7 +232,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if not use_python:
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def unique(x):
return np.unique(x)
......@@ -257,7 +259,7 @@ def numba_funcify_Unique(op, node, **kwargs):
return unique
@numba_funcify.register(UnravelIndex)
@register_funcify_and_cache_key(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
order = op.order
......@@ -289,10 +291,14 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return ((maybe_expand_dim(arr) // a) % shape).T
return unravelindex
cache_key = sha256(
str((type(op), op.order, len(node.outputs))).encode()
).hexdigest()
return unravelindex, cache_key
@numba_funcify.register(SearchsortedOp)
@register_funcify_default_op_cache_key(SearchsortedOp)
def numba_funcify_Searchsorted(op, node, **kwargs):
side = op.side
......@@ -319,7 +325,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else:
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def searchsorted(a, v):
return np.searchsorted(a, v, side)
......
......@@ -3,11 +3,11 @@ import warnings
import numba
import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic
import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
get_numba_type,
int_to_float_fn,
numba_funcify,
register_funcify_default_op_cache_key,
)
from pytensor.tensor.nlinalg import (
SVD,
......@@ -20,7 +20,7 @@ from pytensor.tensor.nlinalg import (
)
@numba_funcify.register(SVD)
@register_funcify_default_op_cache_key(SVD)
def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
......@@ -44,19 +44,19 @@ def numba_funcify_SVD(op, node, **kwargs):
return svd
@numba_funcify.register(Det)
@register_funcify_default_op_cache_key(Det)
def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def det(x):
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
return det
@numba_funcify.register(SLogDet)
@register_funcify_default_op_cache_key(SLogDet)
def numba_funcify_SLogDet(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype
......@@ -74,7 +74,7 @@ def numba_funcify_SLogDet(op, node, **kwargs):
return slogdet
@numba_funcify.register(Eig)
@register_funcify_default_op_cache_key(Eig)
def numba_funcify_Eig(op, node, **kwargs):
w_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, w_dtype)
......@@ -86,7 +86,7 @@ def numba_funcify_Eig(op, node, **kwargs):
return eig
@numba_funcify.register(Eigh)
@register_funcify_default_op_cache_key(Eigh)
def numba_funcify_Eigh(op, node, **kwargs):
uplo = op.UPLO
......@@ -113,31 +113,31 @@ def numba_funcify_Eigh(op, node, **kwargs):
else:
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def eigh(x):
return np.linalg.eigh(x)
return eigh
@numba_funcify.register(MatrixInverse)
@register_funcify_default_op_cache_key(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return matrix_inverse
@numba_funcify.register(MatrixPinv)
@register_funcify_default_op_cache_key(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
......
from collections.abc import Callable
from copy import copy, deepcopy
from functools import singledispatch
from hashlib import sha256
from textwrap import dedent
import numba
......@@ -13,7 +14,11 @@ import pytensor.tensor.random.basic as ptr
from pytensor.graph import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify
from pytensor.link.numba.dispatch.basic import (
direct_cast,
numba_funcify,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
......@@ -395,7 +400,7 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
)
@numba_funcify.register
@register_funcify_and_cache_key(RandomVariableWithCoreShape)
def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
core_shape = node.inputs[0]
......@@ -423,28 +428,44 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
inplace_pattern = encode_literals(())
def random_wrapper(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
draws = _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
)
return rng, draws
def random(core_shape, rng, size, *dist_params):
raise NotImplementedError("Non-jitted random variable not implemented")
raise NotImplementedError(
"Numba implementation of RandomVariable cannot be evaluated in Python (non-JIT) mode"
)
@overload(random, jit_options=_jit_options)
def ov_random(core_shape, rng, size, *dist_params):
return random_wrapper
return random
def impl(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
draws = _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None
if size_len is None
else numba_ndarray.to_fixed_tuple(size, size_len),
)
return rng, draws
return impl
rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
random_rv_key_contents = (
type(op),
type(rv_op),
rv_op,
tuple(rv_op_props_dict.items()),
size_len,
core_shape_len,
inplace,
input_bc_patterns,
)
random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
return random, random_rv_key
import math
from hashlib import sha256
import numpy as np
from pytensor.graph.basic import Variable
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
generate_fallback_impl,
numba_funcify,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
......@@ -30,13 +31,16 @@ from pytensor.scalar.basic import (
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
def scalar_op_cache_key(op):
# Scalar Ops don't have _props, because of their weird outputs_types_preference function
# So we create hash differently
return sha256(str(type(op)).encode()).hexdigest()
@register_funcify_and_cache_key(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
if not hasattr(op, "nfunc_spec"):
return generate_fallback_impl(op, node, **kwargs)
return generate_fallback_impl(op, node=node, **kwargs), None
scalar_func_path = op.nfunc_spec[0]
scalar_func_numba = None
......@@ -58,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
output_inner_dtype = None
# Cython functions might have an additional argument
cython_func = None
has_pyx_skip_dispatch = False
if scalar_func_path.startswith("scipy.special"):
......@@ -127,20 +132,18 @@ def {scalar_op_fn_name}({", ".join(input_names)}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""
scalar_op_fn = compile_function_src(
scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
scalar_op_fn = compile_numba_function_src(
scalar_op_src,
scalar_op_fn_name,
{**globals(), **global_env},
)
signature = create_numba_signature(node, force_scalar=True)
# Functions that call a function pointer can't be cached
cache_key = None if cython_func else scalar_op_cache_key(op)
return numba_basic.numba_njit(scalar_op_fn), cache_key
return numba_basic.numba_njit(
signature,
# Functions that call a function pointer can't be cached
cache=False,
)(scalar_op_fn)
@numba_funcify.register(Switch)
@register_funcify_and_cache_key(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba_basic.numba_njit
def switch(condition, x, y):
......@@ -149,7 +152,7 @@ def numba_funcify_Switch(op, node, **kwargs):
else:
return y
return switch
return switch, scalar_op_cache_key(op)
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
......@@ -163,28 +166,26 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn = compile_function_src(nary_src, binary_op_name, globals())
nary_fn = compile_numba_function_src(nary_src, binary_op_name, globals())
return nary_fn
@numba_funcify.register(Add)
@register_funcify_and_cache_key(Add)
def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(nary_add_fn), scalar_op_cache_key(op)
@numba_funcify.register(Mul)
@register_funcify_and_cache_key(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(nary_mul_fn), scalar_op_cache_key(op)
@numba_funcify.register(Cast)
@register_funcify_and_cache_key(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
......@@ -192,19 +193,19 @@ def numba_funcify_Cast(op, node, **kwargs):
def cast(x):
return numba_basic.direct_cast(x, dtype)
return cast
return cast, sha256(str((type(op), op.o_type.dtype)).encode()).hexdigest()
@numba_funcify.register(Identity)
@register_funcify_and_cache_key(Identity)
def numba_funcify_type_casting(op, **kwargs):
@numba_basic.numba_njit
def identity(x):
return x
return identity
return identity, scalar_op_cache_key(op)
@numba_funcify.register(Clip)
@register_funcify_and_cache_key(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit
def clip(x, min_val, max_val):
......@@ -215,26 +216,33 @@ def numba_funcify_Clip(op, **kwargs):
else:
return x
return clip
return clip, scalar_op_cache_key(op)
@numba_funcify.register(Composite)
@register_funcify_and_cache_key(Composite)
def numba_funcify_Composite(op, node, **kwargs):
_ = kwargs.pop("storage_map", None)
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
composite_fn, fgraph_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
)
if fgraph_key is None:
composite_key = None
else:
composite_key = sha256(str((type(op), fgraph_key)).encode()).hexdigest()
return composite_fn, composite_key
@numba_funcify.register(Second)
@register_funcify_and_cache_key(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba_basic.numba_njit
def second(x, y):
return y
return second
return second, scalar_op_cache_key(op)
@numba_funcify.register(Reciprocal)
@register_funcify_and_cache_key(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs):
@numba_basic.numba_njit
def reciprocal(x):
......@@ -242,28 +250,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
# `x` is an `int`
return 1 / x
return reciprocal
return reciprocal, scalar_op_cache_key(op)
@numba_funcify.register(Sigmoid)
@register_funcify_and_cache_key(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
return sigmoid
return sigmoid, scalar_op_cache_key(op)
@numba_funcify.register(GammaLn)
@register_funcify_and_cache_key(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs):
@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)
return gammaln
return gammaln, scalar_op_cache_key(op)
@numba_funcify.register(Log1mexp)
@register_funcify_and_cache_key(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit
def logp1mexp(x):
......@@ -272,28 +280,28 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
else:
return np.log(-np.expm1(x))
return logp1mexp
return logp1mexp, scalar_op_cache_key(op)
@numba_funcify.register(Erf)
@register_funcify_and_cache_key(Erf)
def numba_funcify_Erf(op, **kwargs):
@numba_basic.numba_njit
def erf(x):
return math.erf(x)
return erf
return erf, scalar_op_cache_key(op)
@numba_funcify.register(Erfc)
@register_funcify_and_cache_key(Erfc)
def numba_funcify_Erfc(op, **kwargs):
@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)
return erfc
return erfc, scalar_op_cache_key(op)
@numba_funcify.register(Softplus)
@register_funcify_and_cache_key(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
out_dtype = np.dtype(node.outputs[0].type.dtype)
......@@ -309,4 +317,4 @@ def numba_funcify_Softplus(op, node, **kwargs):
value = x
return numba_basic.direct_cast(value, out_dtype)
return softplus
return softplus, scalar_op_cache_key(op)
from hashlib import sha256
from textwrap import dedent, indent
import numpy as np
......@@ -7,13 +8,14 @@ from numba.extending import overload
from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
numba_funcify,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
)
from pytensor.link.utils import compile_function_src
from pytensor.scan.op import Scan
from pytensor.tensor.type import TensorType
......@@ -54,7 +56,7 @@ def array0d_range(x):
return range_arr
@numba_funcify.register(Scan)
@register_funcify_and_cache_key(Scan)
def numba_funcify_Scan(op: Scan, node, **kwargs):
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
......@@ -97,7 +99,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
rewriter(fgraph)
scan_inner_func = numba_funcify(op.fgraph)
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
......@@ -439,6 +441,18 @@ def scan({", ".join(outer_in_names)}):
"scan_inner_func": scan_inner_func,
}
scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
scan_op_fn = compile_numba_function_src(
scan_op_src,
"scan",
{**globals(), **global_env},
)
if inner_func_cache_key is None:
# If we can't cache the inner function, we can't cache the Scan either
scan_cache_key = None
else:
scan_cache_key = sha256(
f"({scan_op_src}, {inner_func_cache_key})".encode()
).hexdigest()
return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
return numba_basic.numba_njit(scan_op_fn, boundscheck=False), scan_cache_key
......@@ -4,14 +4,16 @@ import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import compile_function_src
from pytensor.tensor import NoneConst
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register(Shape)
@register_funcify_default_op_cache_key(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba_basic.numba_njit
def shape(x):
......@@ -20,7 +22,7 @@ def numba_funcify_Shape(op, **kwargs):
return shape
@numba_funcify.register(Shape_i)
@register_funcify_default_op_cache_key(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
......@@ -31,7 +33,7 @@ def numba_funcify_Shape_i(op, **kwargs):
return shape_i
@numba_funcify.register(SpecifyShape)
@register_funcify_default_op_cache_key(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_inputs = node.inputs[1:]
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
......@@ -53,10 +55,10 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
)
specify_shape = compile_function_src(func, "specify_shape", globals())
return numba_njit(specify_shape)
return numba_basic.numba_njit(specify_shape)
@numba_funcify.register(Reshape)
@register_funcify_default_op_cache_key(Reshape)
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
......
......@@ -2,11 +2,13 @@ import numpy as np
from numba.np.arraymath import _get_inner_prod
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key,
)
from pytensor.tensor.signal.conv import Convolve1d
@numba_funcify.register(Convolve1d)
@register_funcify_default_op_cache_key(Convolve1d)
def numba_funcify_Convolve1d(op, node, **kwargs):
# This specialized version is faster than the overloaded numba np.convolve
a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
......
......@@ -4,7 +4,10 @@ import numpy as np
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.basic import (
numba_funcify,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_lu_1,
......@@ -91,7 +94,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return cholesky
@numba_funcify.register(PivotToPermutations)
@register_funcify_default_op_cache_key(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse
dtype = node.outputs[0].dtype
......@@ -119,7 +122,7 @@ def numba_funcify_LU(op, node, **kwargs):
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def lu(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -181,11 +184,10 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return lu_factor
@numba_funcify.register(BlockDiagonal)
@register_funcify_default_op_cache_key(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
......@@ -338,7 +340,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype
@numba_basic.numba_njit(cache=False)
@numba_basic.numba_njit
def qr(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......
......@@ -3,11 +3,13 @@ import warnings
import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key,
)
from pytensor.tensor.sort import ArgSortOp, SortOp
@numba_funcify.register(SortOp)
@register_funcify_default_op_cache_key(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
if op.kind != "quicksort":
warnings.warn(
......@@ -31,7 +33,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
return sort_f
@numba_funcify.register(ArgSortOp)
@register_funcify_default_op_cache_key(ArgSortOp)
def numba_funcify_ArgSortOp(op, node, **kwargs):
kind = op.kind
......
import operator
import sys
from hashlib import sha256
import numba
import numpy as np
......@@ -7,11 +8,17 @@ from llvmlite import ir
from numba import types
from numba.core.pythonapi import box
import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.graph import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.link.numba.cache import (
compile_numba_function_src,
)
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
......@@ -98,7 +105,7 @@ def enable_slice_boxing():
enable_slice_boxing()
@numba_funcify.register(MakeSlice)
@register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit
def makeslice(*x):
......@@ -107,9 +114,32 @@ def numba_funcify_MakeSlice(op, **kwargs):
return makeslice
@numba_funcify.register(Subtensor)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def subtensor_op_cache_key(op, **extra_fields):
key_parts = [type(op), tuple(extra_fields.items())]
if hasattr(op, "idx_list"):
idx_parts = []
for idx in op.idx_list:
if isinstance(idx, slice):
idx_parts.append(
(
idx.start is None,
idx.stop is None,
idx.step is None,
)
)
else:
idx_parts.append("i")
key_parts.append(tuple(idx_parts))
if isinstance(op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
key_parts.append((op.inplace, op.set_instead_of_inc))
if isinstance(op, AdvancedIncSubtensor):
key_parts.append(op.ignore_duplicates)
return sha256(str(tuple(key_parts)).encode()).hexdigest()
@register_funcify_and_cache_key(Subtensor)
@register_funcify_and_cache_key(IncSubtensor)
@register_funcify_and_cache_key(AdvancedSubtensor1)
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""
......@@ -185,16 +215,17 @@ def {function_name}({", ".join(input_names)}):
return np.asarray(z)
"""
func = compile_function_src(
func = compile_numba_function_src(
subtensor_def_src,
function_name=function_name,
global_env=globals() | {"np": np},
)
return numba_njit(func, boundscheck=True)
cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
return numba_basic.numba_njit(func, boundscheck=True), cache_key
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
@register_funcify_and_cache_key(AdvancedSubtensor)
@register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
_x, _y, idxs = node.inputs[0], None, node.inputs[1:]
......@@ -255,7 +286,9 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
)
):
return generate_fallback_impl(op, node, **kwargs)
return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key(
op, func="fallback_impl"
)
# What's left should all be supported natively by numba
return numba_funcify_default_subtensor(op, node, **kwargs)
......@@ -295,6 +328,7 @@ def numba_funcify_multiple_integer_vector_indexing(
vector_indices = idxs[first_axis:after_last_axis]
assert all(v.type.broadcastable == (False,) for v in vector_indices)
y_is_broadcasted = False
if isinstance(op, AdvancedSubtensor):
......@@ -313,7 +347,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer
return advanced_subtensor_multiple_vector
ret_func = advanced_subtensor_multiple_vector
else:
inplace = op.inplace
......@@ -347,7 +381,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out
return advanced_set_subtensor_multiple_vector
ret_func = advanced_set_subtensor_multiple_vector
else:
......@@ -369,10 +403,17 @@ def numba_funcify_multiple_integer_vector_indexing(
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out
return advanced_inc_subtensor_multiple_vector
ret_func = advanced_inc_subtensor_multiple_vector
cache_key = subtensor_op_cache_key(
op,
func="multiple_integer_vector_indexing",
y_is_broadcasted=y_is_broadcasted,
)
return ret_func, cache_key
@numba_funcify.register(AdvancedIncSubtensor1)
@register_funcify_and_cache_key(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
......@@ -436,8 +477,14 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x[idx] += val
return x
cache_key = subtensor_op_cache_key(
op,
func="numba_funcify_advancedincsubtensor1",
broadcast_with_index=broadcast_with_index,
)
if inplace:
return advancedincsubtensor1_inplace
return advancedincsubtensor1_inplace, cache_key
else:
......@@ -446,4 +493,4 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
return advancedincsubtensor1
return advancedincsubtensor1, cache_key
from hashlib import sha256
from textwrap import indent
import numpy as np
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.link.utils import unique_name_generator
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -23,7 +26,7 @@ from pytensor.tensor.basic import (
)
@numba_funcify.register(AllocEmpty)
@register_funcify_default_op_cache_key(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {
"np": np,
......@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}):
return np.empty(scalar_shape, dtype)
"""
alloc_fn = compile_function_src(
alloc_fn = compile_numba_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env}
)
return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(Alloc)
@register_funcify_and_cache_key(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np}
......@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}):
res[...] = val
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
alloc_fn = compile_numba_function_src(
alloc_def_src,
"alloc",
{**globals(), **global_env},
)
return numba_basic.numba_njit(alloc_fn)
cache_key = sha256(
str((type(op), node.inputs[0].type.broadcastable)).encode()
).hexdigest()
return numba_basic.numba_njit(alloc_fn), cache_key
@numba_funcify.register(ARange)
@register_funcify_default_op_cache_key(ARange)
def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def arange(start, stop, step):
return np.arange(
start.item(),
......@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs):
return arange
@numba_funcify.register(Join)
@register_funcify_default_op_cache_key(Join)
def numba_funcify_Join(op, **kwargs):
@numba_basic.numba_njit
def join(axis, *tensors):
......@@ -126,7 +136,7 @@ def numba_funcify_Join(op, **kwargs):
return join
@numba_funcify.register(Split)
@register_funcify_default_op_cache_key(Split)
def numba_funcify_Split(op, **kwargs):
@numba_basic.numba_njit
def split(tensor, axis, indices):
......@@ -135,14 +145,14 @@ def numba_funcify_Split(op, **kwargs):
return split
@numba_funcify.register(ExtractDiag)
@register_funcify_default_op_cache_key(ExtractDiag)
def numba_funcify_ExtractDiag(op, node, **kwargs):
view = op.view
axis1, axis2, offset = op.axis1, op.axis2, op.offset
if node.inputs[0].type.ndim == 2:
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def extract_diag(x):
out = np.diag(x, k=offset)
......@@ -157,7 +167,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
leading_dims = (slice(None),) * axis1
middle_dims = (slice(None),) * (axis2 - axis1 - 1)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def extract_diag(x):
if offset >= 0:
diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
......@@ -178,11 +188,11 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
return extract_diag
@numba_funcify.register(Eye)
@register_funcify_default_op_cache_key(Eye)
def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit
def eye(N, M, k):
return np.eye(
N.item(),
......@@ -194,7 +204,7 @@ def numba_funcify_Eye(op, **kwargs):
return eye
@numba_funcify.register(MakeVector)
@register_funcify_default_op_cache_key(MakeVector)
def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype)
......@@ -215,32 +225,34 @@ def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=dtype)
"""
makevector_fn = compile_function_src(
makevector_def_src, "makevector", {**globals(), **global_env}
makevector_fn = compile_numba_function_src(
makevector_def_src,
"makevector",
{**globals(), **global_env},
)
return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(TensorFromScalar)
@register_funcify_default_op_cache_key(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
    # NOTE(review): the two registration decorators above (and the two njit
    # decorators below) are the old/new sides of a diff rendering; only one
    # of each pair exists in the actual file.
    @numba_basic.numba_njit(inline="always")
    @numba_basic.numba_njit
    def tensor_from_scalar(x):
        # Promote the scalar input to a 0-d ndarray.
        return np.array(x)

    return tensor_from_scalar
@numba_funcify.register(ScalarFromTensor)
@register_funcify_default_op_cache_key(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
    # NOTE(review): the duplicated decorators above/below are the old/new
    # sides of a diff rendering; only one of each pair exists in the file.
    @numba_basic.numba_njit(inline="always")
    @numba_basic.numba_njit
    def scalar_from_tensor(x):
        # .item() extracts the underlying scalar from a 0-d array.
        return x.item()

    return scalar_from_tensor
@numba_funcify.register(Nonzero)
@register_funcify_default_op_cache_key(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_basic.numba_njit
def nonzero(a):
......
......@@ -4,7 +4,7 @@ import base64
import pickle
from collections.abc import Callable, Sequence
from textwrap import indent
from typing import Any, cast
from typing import Any
import numba
import numpy as np
......@@ -15,8 +15,8 @@ from numba.core.base import BaseContext
from numba.core.types.misc import NoneType
from numba.np import arrayobj
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.utils import compile_function_src
def encode_literals(literals: Sequence) -> str:
......@@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}):
{indent(store_outputs, " " * 4)}
"""
global_env = {"core_op_fn": core_op_fn}
func = compile_function_src(
func_src, "store_core_outputs", {**globals(), **global_env}
func = compile_numba_function_src(
func_src,
"store_core_outputs",
{**globals(), **global_env},
)
return cast(Callable, numba_basic.numba_njit(func))
return numba_basic.numba_njit(func)
_jit_options = {
......@@ -74,7 +77,7 @@ _jit_options = {
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
typingctx,
scalar_func,
core_func,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
......@@ -85,7 +88,7 @@ def _vectorized(
size_type,
):
arg_types = [
scalar_func,
core_func,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
......@@ -173,16 +176,6 @@ def _vectorized(
)
out_types[output_idx] = output_type
core_signature = typingctx.resolve_function_type(
scalar_func,
[
*constant_inputs_types,
*core_input_types,
*core_out_types,
],
{},
)
ret_type = types.Tuple(out_types)
if len(output_dtypes) == 1:
......@@ -239,11 +232,21 @@ def _vectorized(
output_core_shapes,
)
core_signature = typingctx.resolve_function_type(
core_func,
[
*constant_inputs_types,
*core_input_types,
*core_out_types,
],
{},
)
make_loop_call(
typingctx,
ctx,
builder,
scalar_func,
core_func,
core_signature,
iter_shape,
constant_inputs,
......@@ -416,8 +419,8 @@ def make_loop_call(
typingctx,
context: numba.core.base.BaseContext,
builder: ir.IRBuilder,
scalar_func: Any,
scalar_signature: types.FunctionType,
core_func: Any,
core_signature: types.FunctionType,
iter_shape: tuple[ir.Instruction, ...],
constant_inputs: tuple[ir.Instruction, ...],
inputs: tuple[ir.Instruction, ...],
......@@ -557,10 +560,10 @@ def make_loop_call(
val = core_array._getvalue()
output_slices.append(val)
inner_codegen = context.get_function(scalar_func, scalar_signature)
inner_codegen = context.get_function(core_func, core_signature)
if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
if isinstance(core_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
input_vals = [context.make_tuple(builder, core_signature.args[0], input_vals)]
inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices])
......
......@@ -13,21 +13,32 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark = pytest.mark.filterwarnings("error")
# NOTE(review): this span interleaves both sides of a diff — the duplicated
# parametrize decorators, the two `def test_convolve1d(...)` lines, and the
# overlapping branch bodies are rendering artifacts, not real code. Only the
# (mode, bcast_order) variant exists after the change.
@pytest.mark.parametrize("bcast_order", (1, 0))
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("x_smaller", (False, True))
def test_convolve1d(x_smaller, mode):
def test_convolve1d(mode, bcast_order):
    x = dmatrix("x")
    y = dmatrix("y")
    if x_smaller:
        out = convolve1d(x[None], y[:, None], mode=mode)
    # Testing two orders because this revealed a bug in the past
    if bcast_order == 0:
        out = convolve1d(x[:, None], y[None, :], mode=mode)
    else:
        out = convolve1d(y[:, None], x[None], mode=mode)
        out = convolve1d(x[None], y[:, None], mode=mode)
    rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run on object mode
    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
    numba_fn, res = compare_numba_and_py(
        [x, y], out, [test_x, test_y], eval_obj_mode=False
    )
    # Try other order of inputs, as implementation depends on it
    # Result should be the same, just in different order, except for 'same' mode
    if mode != "same":
        np.testing.assert_allclose(
            np.swapaxes(numba_fn(test_y, test_x), 0, 1),
            res,
        )
@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
......
......@@ -402,7 +402,9 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] == {
"afn",
"arcp",
......@@ -413,7 +415,9 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] is False
......@@ -422,9 +426,10 @@ def test_config_options_cached():
with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
# Caching is disabled unless the dispatched function returns an explicit cache key
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert not isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论