Commit 6ac5ab28 authored by Ricardo Vieira, committed by Ricardo Vieira

Cache keys for numba Op dispatches

Parent 74ab0383
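Note: throughout this commit the dispatch contract changes. A funcify function no longer returns just a compiled function; it returns a (function, cache_key) pair, where cache_key is a stable string, or None when the compiled function cannot safely be cached (for example when it closes over a function pointer or an uncacheable inner graph). Ops register either with register_funcify_and_cache_key (hand-written key) or register_funcify_default_op_cache_key (key derived from the Op type and its props). Those helpers live in pytensor.link.numba.dispatch.basic, which this diff does not show; the sketch below is a hedged reconstruction of what the default key plausibly looks like, not the actual implementation:

    from hashlib import sha256

    def default_op_cache_key(op):
        # Assumption: the Op's compiled code is fully determined by its type
        # and its declared __props__ (true for ops like DeepCopyOp or Eye).
        props = tuple((name, getattr(op, name)) for name in getattr(op, "__props__", ()))
        return sha256(str((type(op), props)).encode()).hexdigest()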
-import sys
+from hashlib import sha256
 from typing import cast

 from numba.core.extending import overload
 from numba.np.unsafe.ndarray import to_fixed_tuple

+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify_and_cache_key,
+    register_funcify_and_cache_key,
+)
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
     _vectorized,
     encode_literals,
     store_core_outputs,
 )
-from pytensor.link.utils import compile_function_src
 from pytensor.tensor import TensorVariable, get_vector_length
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


-@numba_funcify.register(BlockwiseWithCoreShape)
+@register_funcify_and_cache_key(BlockwiseWithCoreShape)
 def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     [blockwise_node] = op.fgraph.apply_nodes
     blockwise_op: Blockwise = blockwise_node.op
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         cast(tuple[TensorVariable], node.inputs[:nin]),
         propagate_unbatched_core_inputs=True,
     )
-    core_op_fn = numba_funcify(
+    core_op_fn, core_op_key = numba_funcify_and_cache_key(
         core_op,
         node=core_node,
         parent_node=node,
@@ -58,18 +61,25 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     src += ")"

     to_tuple = numba_basic.numba_njit(
-        compile_function_src(
+        compile_numba_function_src(
             src,
             "to_tuple",
             global_env={"to_fixed_tuple": to_fixed_tuple},
-        ),
-        # cache=True leads to a numba.cloudpickle dump failure in Python 3.10
-        # May be fine in Python 3.11, but I didn't test. It was fine in 3.12
-        cache=sys.version_info >= (3, 12),
+        )
     )

-    def blockwise_wrapper(*inputs_and_core_shapes):
-        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
+    def blockwise(*inputs_and_core_shapes):
+        raise NotImplementedError(
+            "Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
+        )
+
+    @overload(blockwise, jit_options=_jit_options)
+    def ov_blockwise(*inputs_and_core_shapes):
+        def impl(*inputs_and_core_shapes):
+            inputs, core_shapes = (
+                inputs_and_core_shapes[:nin],
+                inputs_and_core_shapes[nin:],
+            )
             tuple_core_shapes = to_tuple(core_shapes)
             return _vectorized(
                 core_op_fn,
@@ -83,11 +93,24 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
                 None,  # size
             )
+        return impl

-    def blockwise(*inputs_and_core_shapes):
-        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
-
-    @overload(blockwise, jit_options=_jit_options)
-    def ov_blockwise(*inputs_and_core_shapes):
-        return blockwise_wrapper
-
-    return blockwise
+    if core_op_key is None:
+        # If the core op cannot be cached, the Blockwise wrapper cannot be cached either
+        blockwise_key = None
+    else:
+        blockwise_key = "_".join(
+            map(
+                str,
+                (
+                    type(op),
+                    type(blockwise_op),
+                    tuple(blockwise_op.destroy_map.items()),
+                    blockwise_op.signature,
+                    input_bc_patterns,
+                    core_op_key,
+                ),
+            )
+        )
+        blockwise_key = sha256(blockwise_key.encode()).hexdigest()
+    return blockwise, blockwise_key
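Note: the key composition above is the pattern the rest of the commit repeats: hash the wrapper's own compile-time identity (op type, signature, destroy map, broadcast patterns) together with the inner function's key, and propagate None whenever the inner key is None, since a wrapper is only cacheable if everything it closes over is. A minimal distillation of the pattern (helper name hypothetical):

    from hashlib import sha256

    def compose_cache_key(outer_identity: tuple, inner_key: str | None) -> str | None:
        # A wrapper is only cacheable if the function it wraps is cacheable.
        if inner_key is None:
            return None
        return sha256(str((*outer_identity, inner_key)).encode()).hexdigest()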
+from hashlib import sha256
+
 import numpy as np

 from pytensor.compile.builders import OpFromGraph
@@ -8,14 +10,15 @@ from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.ifelse import IfElse
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
-    numba_funcify,
-    numba_njit,
+    numba_funcify_and_cache_key,
+    register_funcify_and_cache_key,
+    register_funcify_default_op_cache_key,
 )
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.type import TensorType


-@numba_funcify.register(OpFromGraph)
+@register_funcify_and_cache_key(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     _ = kwargs.pop("storage_map", None)
@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
         accept_inplace=True,
     )
     NUMBA.optimizer(fgraph)
-    return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
+    fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
+        op.fgraph, squeeze_output=True, **kwargs
+    )
+    if fgraph_cache_key is None:
+        # Can't cache the inner graph
+        ofg_cache_key = None
+    else:
+        ofg_cache_key = sha256(
+            str(
+                (
+                    type(op),
+                    fgraph_cache_key,
+                )
+            ).encode()
+        ).hexdigest()
+    return fgraph_fn, ofg_cache_key


-@numba_funcify.register(TypeCastingOp)
+@register_funcify_default_op_cache_key(TypeCastingOp)
 def numba_funcify_type_casting(op, **kwargs):
     @numba_basic.numba_njit
     def identity(x):
@@ -42,7 +62,7 @@ def numba_funcify_type_casting(op, **kwargs):
     return identity


-@numba_funcify.register(DeepCopyOp)
+@register_funcify_default_op_cache_key(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):
     if isinstance(node.inputs[0].type, TensorType):
@@ -59,7 +79,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
     return deepcopy


-@numba_funcify.register(IfElse)
+@register_funcify_default_op_cache_key(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs
@@ -88,7 +108,7 @@ def numba_funcify_IfElse(op, **kwargs):
     return ifelse


-@numba_funcify.register(CheckAndRaise)
+@register_funcify_and_cache_key(CheckAndRaise)
 def numba_funcify_CheckAndRaise(op, node, **kwargs):
     error = op.exc_type
     msg = op.msg
@@ -100,4 +120,5 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs):
             raise error(msg)
         return x

-    return check_and_raise
+    cache_key = sha256(str((type(op), error, msg)).encode()).hexdigest()
+    return check_and_raise, cache_key
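Note: CheckAndRaise gets a hand-written key that mixes in the exception type and message, both of which are baked into the compiled closure, so two ops differing only in msg do not share a cache entry. A self-contained sketch of that property:

    from hashlib import sha256

    def check_and_raise_key(op_type_name: str, error: type, msg: str) -> str:
        # Mirrors the keying above: the message is part of the compiled closure
        return sha256(str((op_type_name, error, msg)).encode()).hexdigest()

    assert check_and_raise_key("CheckAndRaise", ValueError, "x > 0") != check_and_raise_key(
        "CheckAndRaise", ValueError, "x >= 0"
    )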
 import warnings
+from hashlib import sha256
 from typing import cast

 import numba
@@ -9,7 +10,8 @@ from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     generate_fallback_impl,
     get_numba_type,
-    numba_funcify,
+    register_funcify_and_cache_key,
+    register_funcify_default_op_cache_key,
 )
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.extra_ops import (
@@ -25,16 +27,16 @@ from pytensor.tensor.extra_ops import (
 )


-@numba_funcify.register(Bartlett)
+@register_funcify_default_op_cache_key(Bartlett)
 def numba_funcify_Bartlett(op, **kwargs):
-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def bartlett(x):
         return np.bartlett(x.item())

     return bartlett


-@numba_funcify.register(CumOp)
+@register_funcify_default_op_cache_key(CumOp)
 def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     axis = op.axis
     mode = op.mode
@@ -94,7 +96,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     return cumop


-@numba_funcify.register(FillDiagonal)
+@register_funcify_default_op_cache_key(FillDiagonal)
 def numba_funcify_FillDiagonal(op, **kwargs):
     @numba_basic.numba_njit
     def filldiagonal(a, val):
@@ -104,7 +106,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
     return filldiagonal


-@numba_funcify.register(FillDiagonalOffset)
+@register_funcify_default_op_cache_key(FillDiagonalOffset)
 def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
     @numba_basic.numba_njit
     def filldiagonaloffset(a, val, offset):
@@ -129,7 +131,7 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
     return filldiagonaloffset


-@numba_funcify.register(RavelMultiIndex)
+@register_funcify_default_op_cache_key(RavelMultiIndex)
 def numba_funcify_RavelMultiIndex(op, node, **kwargs):
     mode = op.mode
     order = op.order
@@ -194,7 +196,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
     return ravelmultiindex


-@numba_funcify.register(Repeat)
+@register_funcify_default_op_cache_key(Repeat)
 def numba_funcify_Repeat(op, node, **kwargs):
     axis = op.axis
     a, _ = node.inputs
@@ -202,7 +204,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
     # Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
     if axis == 0 and a.type.ndim == 1:

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def repeatop(x, repeats):
             return np.repeat(x, repeats)
@@ -212,7 +214,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
     return generate_fallback_impl(op, node)


-@numba_funcify.register(Unique)
+@register_funcify_default_op_cache_key(Unique)
 def numba_funcify_Unique(op, node, **kwargs):
     axis = op.axis
@@ -230,7 +232,7 @@ def numba_funcify_Unique(op, node, **kwargs):
     if not use_python:

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def unique(x):
             return np.unique(x)
@@ -257,7 +259,7 @@ def numba_funcify_Unique(op, node, **kwargs):
     return unique


-@numba_funcify.register(UnravelIndex)
+@register_funcify_and_cache_key(UnravelIndex)
 def numba_funcify_UnravelIndex(op, node, **kwargs):
     order = op.order
@@ -289,10 +291,14 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
         # unpacked into a `tuple`, so this discrepancy shouldn't really matter
         return ((maybe_expand_dim(arr) // a) % shape).T

-    return unravelindex
+    cache_key = sha256(
+        str((type(op), op.order, len(node.outputs))).encode()
+    ).hexdigest()
+    return unravelindex, cache_key


-@numba_funcify.register(SearchsortedOp)
+@register_funcify_default_op_cache_key(SearchsortedOp)
 def numba_funcify_Searchsorted(op, node, **kwargs):
     side = op.side
@@ -319,7 +325,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
     else:

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def searchsorted(a, v):
             return np.searchsorted(a, v, side)
...
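Note: UnravelIndex is the one Op in this file that needs a manual key. The default key is derived from the Op alone, but here the compiled function's output arity depends on node-level information (the number of outputs being unraveled), so that must be mixed into the key. A self-contained sketch of why:

    from hashlib import sha256

    def unravel_index_key(op_type_name: str, order: str, n_outputs: int) -> str:
        # Node-level information (the number of outputs) is part of the key
        # because the compiled function's output arity depends on it.
        return sha256(str((op_type_name, order, n_outputs)).encode()).hexdigest()

    assert unravel_index_key("UnravelIndex", "C", 2) != unravel_index_key("UnravelIndex", "C", 3)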
@@ -3,11 +3,11 @@ import warnings
 import numba
 import numpy as np

-from pytensor.link.numba.dispatch import basic as numba_basic
+import pytensor.link.numba.dispatch.basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     get_numba_type,
     int_to_float_fn,
-    numba_funcify,
+    register_funcify_default_op_cache_key,
 )
 from pytensor.tensor.nlinalg import (
     SVD,
@@ -20,7 +20,7 @@ from pytensor.tensor.nlinalg import (
 )


-@numba_funcify.register(SVD)
+@register_funcify_default_op_cache_key(SVD)
 def numba_funcify_SVD(op, node, **kwargs):
     full_matrices = op.full_matrices
     compute_uv = op.compute_uv
@@ -44,19 +44,19 @@ def numba_funcify_SVD(op, node, **kwargs):
     return svd


-@numba_funcify.register(Det)
+@register_funcify_default_op_cache_key(Det)
 def numba_funcify_Det(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def det(x):
         return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)

     return det


-@numba_funcify.register(SLogDet)
+@register_funcify_default_op_cache_key(SLogDet)
 def numba_funcify_SLogDet(op, node, **kwargs):
     out_dtype_1 = node.outputs[0].type.numpy_dtype
     out_dtype_2 = node.outputs[1].type.numpy_dtype
@@ -74,7 +74,7 @@ def numba_funcify_SLogDet(op, node, **kwargs):
     return slogdet


-@numba_funcify.register(Eig)
+@register_funcify_default_op_cache_key(Eig)
 def numba_funcify_Eig(op, node, **kwargs):
     w_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, w_dtype)
@@ -86,7 +86,7 @@ def numba_funcify_Eig(op, node, **kwargs):
     return eig


-@numba_funcify.register(Eigh)
+@register_funcify_default_op_cache_key(Eigh)
 def numba_funcify_Eigh(op, node, **kwargs):
     uplo = op.UPLO
@@ -113,31 +113,31 @@ def numba_funcify_Eigh(op, node, **kwargs):
     else:

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def eigh(x):
             return np.linalg.eigh(x)

     return eigh


-@numba_funcify.register(MatrixInverse)
+@register_funcify_default_op_cache_key(MatrixInverse)
 def numba_funcify_MatrixInverse(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def matrix_inverse(x):
         return np.linalg.inv(inputs_cast(x)).astype(out_dtype)

     return matrix_inverse


-@numba_funcify.register(MatrixPinv)
+@register_funcify_default_op_cache_key(MatrixPinv)
 def numba_funcify_MatrixPinv(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def matrixpinv(x):
         return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
...
 from collections.abc import Callable
 from copy import copy, deepcopy
 from functools import singledispatch
+from hashlib import sha256
 from textwrap import dedent

 import numba
@@ -13,7 +14,11 @@ import pytensor.tensor.random.basic as ptr
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    direct_cast,
+    numba_funcify,
+    register_funcify_and_cache_key,
+)
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
     _vectorized,
@@ -395,7 +400,7 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
     )


-@numba_funcify.register
+@register_funcify_and_cache_key(RandomVariableWithCoreShape)
 def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
     core_shape = node.inputs[0]
@@ -423,7 +428,14 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
     output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
     inplace_pattern = encode_literals(())

-    def random_wrapper(core_shape, rng, size, *dist_params):
+    def random(core_shape, rng, size, *dist_params):
+        raise NotImplementedError(
+            "Numba implementation of RandomVariable cannot be evaluated in Python (non-JIT) mode"
+        )
+
+    @overload(random, jit_options=_jit_options)
+    def ov_random(core_shape, rng, size, *dist_params):
+        def impl(core_shape, rng, size, *dist_params):
             if not inplace:
                 rng = copy(rng)
@@ -436,15 +448,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
                 (rng,),
                 dist_params,
                 (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
-                None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
+                None
+                if size_len is None
+                else numba_ndarray.to_fixed_tuple(size, size_len),
             )
             return rng, draws

+        return impl

-    def random(core_shape, rng, size, *dist_params):
-        raise NotImplementedError("Non-jitted random variable not implemented")
-
-    @overload(random, jit_options=_jit_options)
-    def ov_random(core_shape, rng, size, *dist_params):
-        return random_wrapper
-
-    return random
+    rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
+    random_rv_key_contents = (
+        type(op),
+        type(rv_op),
+        rv_op,
+        tuple(rv_op_props_dict.items()),
+        size_len,
+        core_shape_len,
+        inplace,
+        input_bc_patterns,
+    )
+    random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
+    return random, random_rv_key
 import math
+from hashlib import sha256

 import numpy as np

 from pytensor.graph.basic import Variable
+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
-    create_numba_signature,
     generate_fallback_impl,
-    numba_funcify,
+    numba_funcify_and_cache_key,
+    register_funcify_and_cache_key,
 )
 from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
-    compile_function_src,
     get_name_for_object,
     unique_name_generator,
 )
@@ -30,13 +31,16 @@ from pytensor.scalar.basic import (
 from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus


-@numba_funcify.register(ScalarOp)
-def numba_funcify_ScalarOp(op, node, **kwargs):
-    # TODO: Do we need to cache these functions so that we don't end up
-    # compiling the same Numba function over and over again?
+def scalar_op_cache_key(op):
+    # Scalar Ops don't have _props, because of their weird outputs_types_preference function
+    # So we create the hash differently
+    return sha256(str(type(op)).encode()).hexdigest()


+@register_funcify_and_cache_key(ScalarOp)
+def numba_funcify_ScalarOp(op, node, **kwargs):
     if not hasattr(op, "nfunc_spec"):
-        return generate_fallback_impl(op, node, **kwargs)
+        return generate_fallback_impl(op, node=node, **kwargs), None

     scalar_func_path = op.nfunc_spec[0]
     scalar_func_numba = None
@@ -58,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     output_inner_dtype = None

     # Cython functions might have an additional argument
+    cython_func = None
     has_pyx_skip_dispatch = False

     if scalar_func_path.startswith("scipy.special"):
@@ -127,20 +132,18 @@ def {scalar_op_fn_name}({", ".join(input_names)}):
     return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
     """

-    scalar_op_fn = compile_function_src(
-        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
+    scalar_op_fn = compile_numba_function_src(
+        scalar_op_src,
+        scalar_op_fn_name,
+        {**globals(), **global_env},
     )

-    signature = create_numba_signature(node, force_scalar=True)
-
-    return numba_basic.numba_njit(
-        signature,
-        # Functions that call a function pointer can't be cached
-        cache=False,
-    )(scalar_op_fn)
+    # Functions that call a function pointer can't be cached
+    cache_key = None if cython_func else scalar_op_cache_key(op)
+    return numba_basic.numba_njit(scalar_op_fn), cache_key


-@numba_funcify.register(Switch)
+@register_funcify_and_cache_key(Switch)
 def numba_funcify_Switch(op, node, **kwargs):
     @numba_basic.numba_njit
     def switch(condition, x, y):
@@ -149,7 +152,7 @@ def numba_funcify_Switch(op, node, **kwargs):
         else:
             return y

-    return switch
+    return switch, scalar_op_cache_key(op)


 def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
@@ -163,28 +166,26 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
 def {binary_op_name}({input_signature}):
     return {output_expr}
     """
-    nary_fn = compile_function_src(nary_src, binary_op_name, globals())
+    nary_fn = compile_numba_function_src(nary_src, binary_op_name, globals())

     return nary_fn


-@numba_funcify.register(Add)
+@register_funcify_and_cache_key(Add)
 def numba_funcify_Add(op, node, **kwargs):
-    signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(nary_add_fn), scalar_op_cache_key(op)


-@numba_funcify.register(Mul)
+@register_funcify_and_cache_key(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
-    signature = create_numba_signature(node, force_scalar=True)
-    nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
+    nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(nary_mul_fn), scalar_op_cache_key(op)


-@numba_funcify.register(Cast)
+@register_funcify_and_cache_key(Cast)
 def numba_funcify_Cast(op, node, **kwargs):
     dtype = np.dtype(op.o_type.dtype)
@@ -192,19 +193,19 @@ def numba_funcify_Cast(op, node, **kwargs):
     def cast(x):
         return numba_basic.direct_cast(x, dtype)

-    return cast
+    return cast, sha256(str((type(op), op.o_type.dtype)).encode()).hexdigest()


-@numba_funcify.register(Identity)
+@register_funcify_and_cache_key(Identity)
 def numba_funcify_type_casting(op, **kwargs):
     @numba_basic.numba_njit
     def identity(x):
         return x

-    return identity
+    return identity, scalar_op_cache_key(op)


-@numba_funcify.register(Clip)
+@register_funcify_and_cache_key(Clip)
 def numba_funcify_Clip(op, **kwargs):
     @numba_basic.numba_njit
     def clip(x, min_val, max_val):
@@ -215,26 +216,33 @@ def numba_funcify_Clip(op, **kwargs):
         else:
             return x

-    return clip
+    return clip, scalar_op_cache_key(op)


-@numba_funcify.register(Composite)
+@register_funcify_and_cache_key(Composite)
 def numba_funcify_Composite(op, node, **kwargs):
     _ = kwargs.pop("storage_map", None)

-    return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
+    composite_fn, fgraph_key = numba_funcify_and_cache_key(
+        op.fgraph, squeeze_output=True, **kwargs
+    )
+    if fgraph_key is None:
+        composite_key = None
+    else:
+        composite_key = sha256(str((type(op), fgraph_key)).encode()).hexdigest()
+    return composite_fn, composite_key


-@numba_funcify.register(Second)
+@register_funcify_and_cache_key(Second)
 def numba_funcify_Second(op, node, **kwargs):
     @numba_basic.numba_njit
     def second(x, y):
         return y

-    return second
+    return second, scalar_op_cache_key(op)


-@numba_funcify.register(Reciprocal)
+@register_funcify_and_cache_key(Reciprocal)
 def numba_funcify_Reciprocal(op, node, **kwargs):
     @numba_basic.numba_njit
     def reciprocal(x):
@@ -242,28 +250,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
         # `x` is an `int`
         return 1 / x

-    return reciprocal
+    return reciprocal, scalar_op_cache_key(op)


-@numba_funcify.register(Sigmoid)
+@register_funcify_and_cache_key(Sigmoid)
 def numba_funcify_Sigmoid(op, node, **kwargs):
     @numba_basic.numba_njit
     def sigmoid(x):
         return 1 / (1 + np.exp(-x))

-    return sigmoid
+    return sigmoid, scalar_op_cache_key(op)


-@numba_funcify.register(GammaLn)
+@register_funcify_and_cache_key(GammaLn)
 def numba_funcify_GammaLn(op, node, **kwargs):
     @numba_basic.numba_njit
     def gammaln(x):
         return math.lgamma(x)

-    return gammaln
+    return gammaln, scalar_op_cache_key(op)


-@numba_funcify.register(Log1mexp)
+@register_funcify_and_cache_key(Log1mexp)
 def numba_funcify_Log1mexp(op, node, **kwargs):
     @numba_basic.numba_njit
     def logp1mexp(x):
@@ -272,28 +280,28 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
         else:
             return np.log(-np.expm1(x))

-    return logp1mexp
+    return logp1mexp, scalar_op_cache_key(op)


-@numba_funcify.register(Erf)
+@register_funcify_and_cache_key(Erf)
 def numba_funcify_Erf(op, **kwargs):
     @numba_basic.numba_njit
     def erf(x):
         return math.erf(x)

-    return erf
+    return erf, scalar_op_cache_key(op)


-@numba_funcify.register(Erfc)
+@register_funcify_and_cache_key(Erfc)
 def numba_funcify_Erfc(op, **kwargs):
     @numba_basic.numba_njit
     def erfc(x):
         return math.erfc(x)

-    return erfc
+    return erfc, scalar_op_cache_key(op)


-@numba_funcify.register(Softplus)
+@register_funcify_and_cache_key(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):
     out_dtype = np.dtype(node.outputs[0].type.dtype)
@@ -309,4 +317,4 @@ def numba_funcify_Softplus(op, node, **kwargs):
         value = x
         return numba_basic.direct_cast(value, out_dtype)

-    return softplus
+    return softplus, scalar_op_cache_key(op)
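Note: scalar_op_cache_key deliberately hashes only type(op): ScalarOp instances don't expose usable _props (their constructor takes an outputs_types_preference function), and for the ops above the generated code depends on nothing else. Ops whose code does depend on instance state must opt out, as Cast does by mixing op.o_type.dtype into its key, and as the ScalarOp path does by returning cache_key=None when a Cython function pointer is involved. A self-contained sketch of the invariant this relies on:

    from hashlib import sha256

    class FakeSigmoid:
        """Stand-in for a ScalarOp whose generated code depends only on its type."""

    def scalar_op_cache_key_sketch(op) -> str:
        # Sound only when every instance of type(op) compiles to identical code
        return sha256(str(type(op)).encode()).hexdigest()

    assert scalar_op_cache_key_sketch(FakeSigmoid()) == scalar_op_cache_key_sketch(FakeSigmoid())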
+from hashlib import sha256
 from textwrap import dedent, indent

 import numpy as np
@@ -7,13 +8,14 @@ from numba.extending import overload
 from pytensor import In
 from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.mode import NUMBA, get_mode
+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_arg_string,
     create_tuple_string,
-    numba_funcify,
+    numba_funcify_and_cache_key,
+    register_funcify_and_cache_key,
 )
-from pytensor.link.utils import compile_function_src
 from pytensor.scan.op import Scan
 from pytensor.tensor.type import TensorType
@@ -54,7 +56,7 @@ def array0d_range(x):
     return range_arr


-@numba_funcify.register(Scan)
+@register_funcify_and_cache_key(Scan)
 def numba_funcify_Scan(op: Scan, node, **kwargs):
     # Apply inner rewrites
     # TODO: Not sure this is the right place to do this, should we have a rewrite that
@@ -97,7 +99,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
     )
     rewriter(fgraph)

-    scan_inner_func = numba_funcify(op.fgraph)
+    scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)

     outer_in_names_to_vars = {
         (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
@@ -439,6 +441,18 @@ def scan({", ".join(outer_in_names)}):
         "scan_inner_func": scan_inner_func,
     }

-    scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
+    scan_op_fn = compile_numba_function_src(
+        scan_op_src,
+        "scan",
+        {**globals(), **global_env},
+    )

-    return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
+    if inner_func_cache_key is None:
+        # If we can't cache the inner function, we can't cache the Scan either
+        scan_cache_key = None
+    else:
+        scan_cache_key = sha256(
+            f"({scan_op_src}, {inner_func_cache_key})".encode()
+        ).hexdigest()
+    return numba_basic.numba_njit(scan_op_fn, boundscheck=False), scan_cache_key
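Note: Scan's key hashes the generated scan_op_src itself together with the inner function's key. Hashing the generated source is a compact way to capture every structural decision the code generator made (number of sequences, buffer handling, and so on) without enumerating them by hand: anything that changes the emitted code changes the key. Distilled into a sketch (names hypothetical):

    from hashlib import sha256

    def source_based_key(generated_src: str, inner_key: str | None) -> str | None:
        if inner_key is None:
            return None  # an uncacheable inner function poisons the wrapper too
        return sha256(f"({generated_src}, {inner_key})".encode()).hexdigest()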
@@ -4,14 +4,16 @@ import numpy as np
 from numba.np.unsafe import ndarray as numba_ndarray

 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch import numba_funcify
-from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
+from pytensor.link.numba.dispatch.basic import (
+    create_arg_string,
+    register_funcify_default_op_cache_key,
+)
 from pytensor.link.utils import compile_function_src
 from pytensor.tensor import NoneConst
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape


-@numba_funcify.register(Shape)
+@register_funcify_default_op_cache_key(Shape)
 def numba_funcify_Shape(op, **kwargs):
     @numba_basic.numba_njit
     def shape(x):
@@ -20,7 +22,7 @@ def numba_funcify_Shape(op, **kwargs):
     return shape


-@numba_funcify.register(Shape_i)
+@register_funcify_default_op_cache_key(Shape_i)
 def numba_funcify_Shape_i(op, **kwargs):
     i = op.i
@@ -31,7 +33,7 @@ def numba_funcify_Shape_i(op, **kwargs):
     return shape_i


-@numba_funcify.register(SpecifyShape)
+@register_funcify_default_op_cache_key(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -53,10 +55,10 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
     )

     specify_shape = compile_function_src(func, "specify_shape", globals())
-    return numba_njit(specify_shape)
+    return numba_basic.numba_njit(specify_shape)


-@numba_funcify.register(Reshape)
+@register_funcify_default_op_cache_key(Reshape)
 def numba_funcify_Reshape(op, **kwargs):
     ndim = op.ndim
...
@@ -2,11 +2,13 @@ import numpy as np
 from numba.np.arraymath import _get_inner_prod

 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch import numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    register_funcify_default_op_cache_key,
+)
 from pytensor.tensor.signal.conv import Convolve1d


-@numba_funcify.register(Convolve1d)
+@register_funcify_default_op_cache_key(Convolve1d)
 def numba_funcify_Convolve1d(op, node, **kwargs):
     # This specialized version is faster than the overloaded numba np.convolve
     a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
...
@@ -4,7 +4,10 @@ import numpy as np
 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify,
+    register_funcify_default_op_cache_key,
+)
 from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
 from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
     _lu_1,
@@ -91,7 +94,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
     return cholesky


-@numba_funcify.register(PivotToPermutations)
+@register_funcify_default_op_cache_key(PivotToPermutations)
 def pivot_to_permutation(op, node, **kwargs):
     inverse = op.inverse
     dtype = node.outputs[0].dtype
@@ -119,7 +122,7 @@ def numba_funcify_LU(op, node, **kwargs):
     if dtype in complex_dtypes:
         NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def lu(a):
         if check_finite:
             if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -181,11 +184,10 @@ def numba_funcify_LUFactor(op, node, **kwargs):
     return lu_factor


-@numba_funcify.register(BlockDiagonal)
+@register_funcify_default_op_cache_key(BlockDiagonal)
 def numba_funcify_BlockDiagonal(op, node, **kwargs):
     dtype = node.outputs[0].dtype

-    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
     @numba_basic.numba_njit
     def block_diag(*arrs):
         shapes = np.array([a.shape for a in arrs], dtype="int")
@@ -338,7 +340,7 @@ def numba_funcify_QR(op, node, **kwargs):
     integer_input = dtype in integer_dtypes
     in_dtype = config.floatX if integer_input else dtype

-    @numba_basic.numba_njit(cache=False)
+    @numba_basic.numba_njit
     def qr(a):
         if check_finite:
             if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...
@@ -3,11 +3,13 @@ import warnings
 import numpy as np

 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch import numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    register_funcify_default_op_cache_key,
+)
 from pytensor.tensor.sort import ArgSortOp, SortOp


-@numba_funcify.register(SortOp)
+@register_funcify_default_op_cache_key(SortOp)
 def numba_funcify_SortOp(op, node, **kwargs):
     if op.kind != "quicksort":
         warnings.warn(
@@ -31,7 +33,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
     return sort_f


-@numba_funcify.register(ArgSortOp)
+@register_funcify_default_op_cache_key(ArgSortOp)
 def numba_funcify_ArgSortOp(op, node, **kwargs):
     kind = op.kind
...
 import operator
 import sys
+from hashlib import sha256

 import numba
 import numpy as np
@@ -7,11 +8,17 @@ from llvmlite import ir
 from numba import types
 from numba.core.pythonapi import box

+import pytensor.link.numba.dispatch.basic as numba_basic
 from pytensor.graph import Type
-from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch import numba_funcify
-from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
-from pytensor.link.utils import compile_function_src, unique_name_generator
+from pytensor.link.numba.cache import (
+    compile_numba_function_src,
+)
+from pytensor.link.numba.dispatch.basic import (
+    generate_fallback_impl,
+    register_funcify_and_cache_key,
+    register_funcify_default_op_cache_key,
+)
+from pytensor.link.utils import unique_name_generator
 from pytensor.tensor import TensorType
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
@@ -98,7 +105,7 @@ def enable_slice_boxing():
 enable_slice_boxing()


-@numba_funcify.register(MakeSlice)
+@register_funcify_default_op_cache_key(MakeSlice)
 def numba_funcify_MakeSlice(op, **kwargs):
     @numba_basic.numba_njit
     def makeslice(*x):
@@ -107,9 +114,32 @@ def numba_funcify_MakeSlice(op, **kwargs):
     return makeslice


-@numba_funcify.register(Subtensor)
-@numba_funcify.register(IncSubtensor)
-@numba_funcify.register(AdvancedSubtensor1)
+def subtensor_op_cache_key(op, **extra_fields):
+    key_parts = [type(op), tuple(extra_fields.items())]
+    if hasattr(op, "idx_list"):
+        idx_parts = []
+        for idx in op.idx_list:
+            if isinstance(idx, slice):
+                idx_parts.append(
+                    (
+                        idx.start is None,
+                        idx.stop is None,
+                        idx.step is None,
+                    )
+                )
+            else:
+                idx_parts.append("i")
+        key_parts.append(tuple(idx_parts))
+    if isinstance(op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
+        key_parts.append((op.inplace, op.set_instead_of_inc))
+    if isinstance(op, AdvancedIncSubtensor):
+        key_parts.append(op.ignore_duplicates)
+    return sha256(str(tuple(key_parts)).encode()).hexdigest()
+
+
+@register_funcify_and_cache_key(Subtensor)
+@register_funcify_and_cache_key(IncSubtensor)
+@register_funcify_and_cache_key(AdvancedSubtensor1)
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""
@@ -185,16 +215,17 @@ def {function_name}({", ".join(input_names)}):
     return np.asarray(z)
     """

-    func = compile_function_src(
+    func = compile_numba_function_src(
         subtensor_def_src,
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    return numba_njit(func, boundscheck=True)
+    cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
+    return numba_basic.numba_njit(func, boundscheck=True), cache_key


-@numba_funcify.register(AdvancedSubtensor)
-@numba_funcify.register(AdvancedIncSubtensor)
+@register_funcify_and_cache_key(AdvancedSubtensor)
+@register_funcify_and_cache_key(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
         _x, _y, idxs = node.inputs[0], None, node.inputs[1:]
@@ -255,7 +286,9 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
             )
         )
     ):
-        return generate_fallback_impl(op, node, **kwargs)
+        return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key(
+            op, func="fallback_impl"
+        )

     # What's left should all be supported natively by numba
     return numba_funcify_default_subtensor(op, node, **kwargs)
@@ -295,6 +328,7 @@ def numba_funcify_multiple_integer_vector_indexing(
     vector_indices = idxs[first_axis:after_last_axis]
     assert all(v.type.broadcastable == (False,) for v in vector_indices)

+    y_is_broadcasted = False
     if isinstance(op, AdvancedSubtensor):
@@ -313,7 +347,7 @@ def numba_funcify_multiple_integer_vector_indexing(
                 out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
             return out_buffer

-        return advanced_subtensor_multiple_vector
+        ret_func = advanced_subtensor_multiple_vector

     else:
         inplace = op.inplace
@@ -347,7 +381,7 @@ def numba_funcify_multiple_integer_vector_indexing(
                     out[(*outer, *scalar_idxs)] = y[(*outer, i)]
                 return out

-            return advanced_set_subtensor_multiple_vector
+            ret_func = advanced_set_subtensor_multiple_vector

         else:
@@ -369,10 +403,17 @@ def numba_funcify_multiple_integer_vector_indexing(
                     out[(*outer, *scalar_idxs)] += y[(*outer, i)]
                 return out

-            return advanced_inc_subtensor_multiple_vector
+            ret_func = advanced_inc_subtensor_multiple_vector

+    cache_key = subtensor_op_cache_key(
+        op,
+        func="multiple_integer_vector_indexing",
+        y_is_broadcasted=y_is_broadcasted,
+    )
+    return ret_func, cache_key


-@numba_funcify.register(AdvancedIncSubtensor1)
+@register_funcify_and_cache_key(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
@@ -436,8 +477,14 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
                 x[idx] += val
             return x

+    cache_key = subtensor_op_cache_key(
+        op,
+        func="numba_funcify_advancedincsubtensor1",
+        broadcast_with_index=broadcast_with_index,
+    )
     if inplace:
-        return advancedincsubtensor1_inplace
+        return advancedincsubtensor1_inplace, cache_key

     else:
@@ -446,4 +493,4 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
             x = x.copy()
             return advancedincsubtensor1_inplace(x, vals, idxs)

-        return advancedincsubtensor1
+        return advancedincsubtensor1, cache_key
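Note: subtensor_op_cache_key keys on the structure of the index, not its values. Each slice in op.idx_list contributes only which of start/stop/step is present, and every non-slice entry collapses to "i". That matches how the codegen works: concrete bounds arrive as runtime inputs, so x[1:5] and x[2:7] can share one compiled kernel, while x[::2] cannot. A hedged illustration (assumes constant slice bounds become node inputs rather than idx_list entries, as in current pytensor; the equalities are an untested expectation, not something this diff asserts):

    import pytensor.tensor as pt

    x = pt.vector("x")
    op_a = x[1:5].owner.op   # start and stop present, step absent: (False, False, True)
    op_b = x[2:7].owner.op   # same structure, different constants
    op_c = x[::2].owner.op   # only step present: (True, True, False)

    # Expected under the keying above:
    # subtensor_op_cache_key(op_a) == subtensor_op_cache_key(op_b)
    # subtensor_op_cache_key(op_a) != subtensor_op_cache_key(op_c)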
from hashlib import sha256
from textwrap import indent from textwrap import indent
import numpy as np import numpy as np
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_tuple_string, create_tuple_string,
numba_funcify, register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
) )
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import unique_name_generator
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
...@@ -23,7 +26,7 @@ from pytensor.tensor.basic import ( ...@@ -23,7 +26,7 @@ from pytensor.tensor.basic import (
) )
@numba_funcify.register(AllocEmpty) @register_funcify_default_op_cache_key(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs): def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = { global_env = {
"np": np, "np": np,
...@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}): ...@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}):
return np.empty(scalar_shape, dtype) return np.empty(scalar_shape, dtype)
""" """
alloc_fn = compile_function_src( alloc_fn = compile_numba_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env} alloc_def_src, "allocempty", {**globals(), **global_env}
) )
return numba_basic.numba_njit(alloc_fn) return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(Alloc) @register_funcify_and_cache_key(Alloc)
def numba_funcify_Alloc(op, node, **kwargs): def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np} global_env = {"np": np}
...@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}):
res[...] = val res[...] = val
return res return res
""" """
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) alloc_fn = compile_numba_function_src(
alloc_def_src,
"alloc",
{**globals(), **global_env},
)
return numba_basic.numba_njit(alloc_fn) cache_key = sha256(
str((type(op), node.inputs[0].type.broadcastable)).encode()
).hexdigest()
return numba_basic.numba_njit(alloc_fn), cache_key
@numba_funcify.register(ARange) @register_funcify_default_op_cache_key(ARange)
def numba_funcify_ARange(op, **kwargs): def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit
def arange(start, stop, step): def arange(start, stop, step):
return np.arange( return np.arange(
start.item(), start.item(),
...@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs): ...@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs):
return arange return arange
-@numba_funcify.register(Join)
+@register_funcify_default_op_cache_key(Join)
 def numba_funcify_Join(op, **kwargs):
     @numba_basic.numba_njit
     def join(axis, *tensors):
@@ -126,7 +136,7 @@ def numba_funcify_Join(op, **kwargs):
     return join


-@numba_funcify.register(Split)
+@register_funcify_default_op_cache_key(Split)
 def numba_funcify_Split(op, **kwargs):
     @numba_basic.numba_njit
     def split(tensor, axis, indices):
@@ -135,14 +145,14 @@ def numba_funcify_Split(op, **kwargs):
     return split


-@numba_funcify.register(ExtractDiag)
+@register_funcify_default_op_cache_key(ExtractDiag)
 def numba_funcify_ExtractDiag(op, node, **kwargs):
     view = op.view
     axis1, axis2, offset = op.axis1, op.axis2, op.offset

     if node.inputs[0].type.ndim == 2:

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def extract_diag(x):
             out = np.diag(x, k=offset)
@@ -157,7 +167,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
         leading_dims = (slice(None),) * axis1
         middle_dims = (slice(None),) * (axis2 - axis1 - 1)

-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit
         def extract_diag(x):
             if offset >= 0:
                 diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
@@ -178,11 +188,11 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
     return extract_diag


-@numba_funcify.register(Eye)
+@register_funcify_default_op_cache_key(Eye)
 def numba_funcify_Eye(op, **kwargs):
     dtype = np.dtype(op.dtype)

-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def eye(N, M, k):
         return np.eye(
             N.item(),
@@ -194,7 +204,7 @@ def numba_funcify_Eye(op, **kwargs):
     return eye


-@numba_funcify.register(MakeVector)
+@register_funcify_default_op_cache_key(MakeVector)
 def numba_funcify_MakeVector(op, node, **kwargs):
     dtype = np.dtype(op.dtype)
@@ -215,32 +225,34 @@ def makevector({", ".join(input_names)}):
     return np.array({create_list_string(input_names)}, dtype=dtype)
     """

-    makevector_fn = compile_function_src(
-        makevector_def_src, "makevector", {**globals(), **global_env}
+    makevector_fn = compile_numba_function_src(
+        makevector_def_src,
+        "makevector",
+        {**globals(), **global_env},
     )

     return numba_basic.numba_njit(makevector_fn)
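Reviewer note: compile_numba_function_src, imported from the new pytensor.link.numba.cache module, replaces compile_function_src throughout. Its implementation is not shown here; a minimal exec-based sketch of what such a source-compiling helper typically does (a hypothetical simplification, not the actual code):

def compile_function_src_sketch(src: str, function_name: str, global_env: dict):
    # Execute the generated source in a namespace seeded with global_env,
    # then pull out the function object by name.
    namespace = dict(global_env)
    exec(compile(src, f"<generated {function_name}>", "exec"), namespace)
    return namespace[function_name]


make_pair = compile_function_src_sketch(
    "def make_pair(a, b):\n    return (a, b)\n", "make_pair", {}
)
assert make_pair(1, 2) == (1, 2)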
-@numba_funcify.register(TensorFromScalar)
+@register_funcify_default_op_cache_key(TensorFromScalar)
 def numba_funcify_TensorFromScalar(op, **kwargs):
-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def tensor_from_scalar(x):
         return np.array(x)

     return tensor_from_scalar


-@numba_funcify.register(ScalarFromTensor)
+@register_funcify_default_op_cache_key(ScalarFromTensor)
 def numba_funcify_ScalarFromTensor(op, **kwargs):
-    @numba_basic.numba_njit(inline="always")
+    @numba_basic.numba_njit
     def scalar_from_tensor(x):
         return x.item()

     return scalar_from_tensor


-@numba_funcify.register(Nonzero)
+@register_funcify_default_op_cache_key(Nonzero)
 def numba_funcify_Nonzero(op, node, **kwargs):
     @numba_basic.numba_njit
     def nonzero(a):
...
@@ -4,7 +4,7 @@ import base64
 import pickle
 from collections.abc import Callable, Sequence
 from textwrap import indent
-from typing import Any, cast
+from typing import Any

 import numba
 import numpy as np
@@ -15,8 +15,8 @@ from numba.core.base import BaseContext
 from numba.core.types.misc import NoneType
 from numba.np import arrayobj

+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.utils import compile_function_src


 def encode_literals(literals: Sequence) -> str:
@@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}):
 {indent(store_outputs, " " * 4)}
     """
     global_env = {"core_op_fn": core_op_fn}
-    func = compile_function_src(
-        func_src, "store_core_outputs", {**globals(), **global_env}
+    func = compile_numba_function_src(
+        func_src,
+        "store_core_outputs",
+        {**globals(), **global_env},
     )
-    return cast(Callable, numba_basic.numba_njit(func))
+    return numba_basic.numba_njit(func)


 _jit_options = {
@@ -74,7 +77,7 @@ _jit_options = {
 @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
 def _vectorized(
     typingctx,
-    scalar_func,
+    core_func,
     input_bc_patterns,
     output_bc_patterns,
     output_dtypes,
@@ -85,7 +88,7 @@ def _vectorized(
     size_type,
 ):
     arg_types = [
-        scalar_func,
+        core_func,
         input_bc_patterns,
         output_bc_patterns,
         output_dtypes,
@@ -173,16 +176,6 @@ def _vectorized(
         )
         out_types[output_idx] = output_type

-    core_signature = typingctx.resolve_function_type(
-        scalar_func,
-        [
-            *constant_inputs_types,
-            *core_input_types,
-            *core_out_types,
-        ],
-        {},
-    )

     ret_type = types.Tuple(out_types)

     if len(output_dtypes) == 1:
@@ -239,11 +232,21 @@ def _vectorized(
             output_core_shapes,
         )

+        core_signature = typingctx.resolve_function_type(
+            core_func,
+            [
+                *constant_inputs_types,
+                *core_input_types,
+                *core_out_types,
+            ],
+            {},
+        )
         make_loop_call(
             typingctx,
             ctx,
             builder,
-            scalar_func,
+            core_func,
             core_signature,
             iter_shape,
             constant_inputs,
@@ -416,8 +419,8 @@ def make_loop_call(
     typingctx,
     context: numba.core.base.BaseContext,
     builder: ir.IRBuilder,
-    scalar_func: Any,
-    scalar_signature: types.FunctionType,
+    core_func: Any,
+    core_signature: types.FunctionType,
     iter_shape: tuple[ir.Instruction, ...],
     constant_inputs: tuple[ir.Instruction, ...],
     inputs: tuple[ir.Instruction, ...],
@@ -557,10 +560,10 @@ def make_loop_call(
         val = core_array._getvalue()
         output_slices.append(val)

-    inner_codegen = context.get_function(scalar_func, scalar_signature)
+    inner_codegen = context.get_function(core_func, core_signature)

-    if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
-        input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
+    if isinstance(core_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
+        input_vals = [context.make_tuple(builder, core_signature.args[0], input_vals)]

     inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices])
...
@@ -13,21 +13,32 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
 pytestmark = pytest.mark.filterwarnings("error")


+@pytest.mark.parametrize("bcast_order", (1, 0))
 @pytest.mark.parametrize("mode", ["full", "valid", "same"])
-@pytest.mark.parametrize("x_smaller", (False, True))
-def test_convolve1d(x_smaller, mode):
+def test_convolve1d(mode, bcast_order):
     x = dmatrix("x")
     y = dmatrix("y")
-    if x_smaller:
-        out = convolve1d(x[None], y[:, None], mode=mode)
+    # Testing two orders because this revealed a bug in the past
+    if bcast_order == 0:
+        out = convolve1d(x[:, None], y[None, :], mode=mode)
     else:
-        out = convolve1d(y[:, None], x[None], mode=mode)
+        out = convolve1d(x[None], y[:, None], mode=mode)

     rng = np.random.default_rng()
     test_x = rng.normal(size=(3, 5))
     test_y = rng.normal(size=(7, 11))
     # Blockwise dispatch for numba can't be run on object mode
-    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
+    numba_fn, res = compare_numba_and_py(
+        [x, y], out, [test_x, test_y], eval_obj_mode=False
+    )
+    # Try other order of inputs, as implementation depends on it
+    # Result should be the same, just in different order, except for 'same' mode
+    if mode != "same":
+        np.testing.assert_allclose(
+            np.swapaxes(numba_fn(test_y, test_x), 0, 1),
+            res,
+        )
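Reviewer note on the swapped-input assertion: 1D convolution is commutative, so exchanging the operands of a batched convolve can only permute the broadcasted batch axes, which the swapaxes call in the test undoes. The "same" mode is skipped, presumably because its output shape follows the first operand in this implementation. A quick standalone illustration of the commutativity with plain NumPy:

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=5)
b = rng.normal(size=11)

# np.convolve is commutative in both "full" and "valid" modes.
for mode in ("full", "valid"):
    np.testing.assert_allclose(
        np.convolve(a, b, mode=mode), np.convolve(b, a, mode=mode)
    )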
@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}") @pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
......
@@ -402,7 +402,9 @@ def test_config_options_fastmath():
     with config.change_flags(numba__fastmath=True):
         pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
+        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
+            "jitable_func"
+        ].py_func.__globals__["impl_sum"]
         assert numba_sum_fn.targetoptions["fastmath"] == {
             "afn",
             "arcp",
@@ -413,7 +415,9 @@ def test_config_options_fastmath():
     with config.change_flags(numba__fastmath=False):
         pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
+        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
+            "jitable_func"
+        ].py_func.__globals__["impl_sum"]
         assert numba_sum_fn.targetoptions["fastmath"] is False
@@ -422,9 +426,10 @@ def test_config_options_cached():
     with config.change_flags(numba__cache=True):
         pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
-        # Caching is disabled unless the dispatched function returns an explicit cache key
-        assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
+        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
+            "jitable_func"
+        ].py_func.__globals__["impl_sum"]
+        assert not isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
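Reviewer note: Numba dispatchers carry a private _cache attribute that stays a NullCache placeholder until caching is enabled, which is what the flipped assertion inspects; with cache keys now supplied, sum's compiled function gets a real cache. A standalone illustration of that Numba behavior, independent of PyTensor (run from a script file so Numba has a cache locator):

import numba
import numpy as np
from numba.core.caching import NullCache


@numba.njit
def no_cache(x):
    return x * 2


@numba.njit(cache=True)
def with_cache(x):
    return x * 2


no_cache(np.float64(1.0))
with_cache(np.float64(1.0))
# Without cache=True the dispatcher keeps a NullCache placeholder;
# with cache=True it holds a real on-disk function cache.
assert isinstance(no_cache._cache, NullCache)
assert not isinstance(with_cache._cache, NullCache)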
     with config.change_flags(numba__cache=False):
         pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
...