Commit d4696e6b authored by Brandon T. Willard, committed by Brandon T. Willard

Split aesara.link.numba.dispatch into distinct modules

Parent 0ae63d1b
import inspect
import operator
import warnings
from functools import reduce, singledispatch
from numbers import Number
from textwrap import dedent, indent
from typing import Any, Callable, Dict, List, Optional, Union
import numba
import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import _helperlib, types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.link.utils import (
compile_function_src,
fgraph_to_python,
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import (
Add,
Cast,
Clip,
Composite,
Identity,
Mul,
Scalar,
ScalarOp,
Second,
Switch,
)
from aesara.scalar.math import Softplus
from aesara.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
get_vector_length,
)
from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
Bartlett,
CumOp,
DiffOp,
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.type import TensorType, tensor
from aesara.tensor.type_other import MakeSlice
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type object for a ``Type``."""
if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar:
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar):
dtype = np.dtype(aesara_type.dtype)
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
else:
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types = []
for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar))
output_types = []
for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar))
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
elif len(output_types) == 1:
return output_types[0](*input_types)
else:
return numba.types.void(*input_types)
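# A hedged usage sketch of the two helpers above (illustrative only; the
# dtype and shape below are assumptions chosen for demonstration):
def _example_numba_types():
    x_type = TensorType("float64", (False, False))  # a 2-d float64 tensor type
    # A `TensorType` maps to a Numba array type...
    assert isinstance(get_numba_type(x_type), numba.types.Array)
    # ...unless a scalar is forced, in which case only the element dtype remains.
    assert get_numba_type(x_type, force_scalar=True) == numba.float64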
def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
def enable_slice_boxing():
"""Enable boxing for Numba's native ``slice``s.
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
merged and a release is made.
"""
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
        This makes it possible to return Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.box(types.int64, c.builder.extract_value(val, 0))
stop = c.box(types.int64, c.builder.extract_value(val, 1))
if typ.has_step:
step = c.box(types.int64, c.builder.extract_value(val, 2))
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@numba.extending.overload(operator.contains)
def in_seq_empty_tuple(x, y):
if isinstance(x, types.Tuple) and not x.types:
return lambda x, y: False
enable_slice_boxing()
@numba.generated_jit(nopython=True)
def to_scalar(x):
if isinstance(x, (numba.types.Number, numba.types.Boolean)):
return lambda x: x
elif isinstance(x, numba.types.Array):
return lambda x: x.item()
else:
raise TypingError(f"{x} must be a scalar compatible type.")
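# A hedged sketch of `to_scalar` (illustrative): a 0-d array and a plain
# number normalize to the same scalar value, which is how the generated code
# below can accept either encoding of shapes, sizes, and axes.
def _example_to_scalar():
    assert to_scalar(np.array(3.0)) == 3.0
    assert to_scalar(3) == 3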
def enable_slice_literals():
"""Enable lowering for ``SliceLiteral``s.
TODO: This can be removed once https://github.com/numba/numba/pull/6996 is merged
and a release is made.
"""
from numba.core import types
from numba.core.datamodel.models import SliceModel
from numba.core.datamodel.registry import register_default
from numba.core.imputils import lower_cast, lower_constant
from numba.core.types.misc import SliceLiteral
from numba.cpython.slicing import get_defaults
register_default(numba.types.misc.SliceLiteral)(SliceModel)
@property
def key(self):
return self.name
SliceLiteral.key = key
def make_slice_from_constant(context, builder, ty, pyval):
sli = context.make_helper(builder, ty)
lty = context.get_value_type(types.intp)
(
default_start_pos,
default_start_neg,
default_stop_pos,
default_stop_neg,
default_step,
) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
step = pyval.step
if step is None:
step_is_neg = False
step = default_step
else:
step_is_neg = step < 0
step = lty(step)
start = pyval.start
if start is None:
if step_is_neg:
start = default_start_neg
else:
start = default_start_pos
else:
start = lty(start)
stop = pyval.stop
if stop is None:
if step_is_neg:
stop = default_stop_neg
else:
stop = default_stop_pos
else:
stop = lty(stop)
sli.start = start
sli.stop = stop
sli.step = step
return sli._getvalue()
@lower_constant(numba.types.SliceType)
def constant_slice(context, builder, ty, pyval):
if isinstance(ty, types.Literal):
typ = ty.literal_type
else:
typ = ty
return make_slice_from_constant(context, builder, typ, pyval)
@lower_cast(numba.types.misc.SliceLiteral, numba.types.SliceType)
def cast_from_literal(context, builder, fromty, toty, val):
return make_slice_from_constant(
context,
builder,
toty,
fromty.literal_value,
)
enable_slice_literals()
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
See https://github.com/numba/numba/issues/2771#issuecomment-414358902
"""
assert n > 0
f = numba.njit(f)
@numba.njit
def creator(args):
return (f(0, *args),)
for i in range(1, n):
@numba.njit
def creator(args, creator=creator, i=i):
return creator(args) + (f(i, *args),)
return numba.njit(lambda *args: creator(args))
def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"
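# A hedged sketch of the two tuple helpers (illustrative; the lambda mirrors
# the `res_shape` constructor used in `create_axis_reducer` below):
def _example_tuple_helpers():
    first_two = create_tuple_creator(lambda i, shape: shape[i], 2)
    assert first_two((5, 6, 7)) == (5, 6)
    # `create_tuple_string` emits valid tuple syntax, including the trailing
    # comma required for single-element tuples.
    assert create_tuple_string(["a"]) == "(a, )"
    assert create_tuple_string(["a", "b"]) == "(a, b)"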
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
UserWarning,
)
n_outputs = len(node.outputs)
if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
outputs = tuple([o[0] for o in outputs])
if n_outputs == 1:
ret = outputs[0]
else:
ret = outputs
return ret
return perform
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="numba_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
)
@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0]
if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]
else:
func_package = np
if "." in scalar_func_name:
        scalar_func = reduce(getattr, [func_package] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
scalar_op_fn_name = get_name_for_object(scalar_func)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
)
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
global_env = {"scalar_func": scalar_func}
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
scalar_op_fn = compile_function_src(scalar_op_src, scalar_op_fn_name, global_env)
signature = create_numba_signature(node, force_scalar=True)
return numba.njit(signature, inline="always")(scalar_op_fn)
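# For instance, for an `Erf` scalar `Op` whose `nfunc_spec` names
# `"scipy.special.erf"`, the compiled source looks roughly like this sketch:
#
#     def erf(x):
#         return scalar_func(x)
#
# with `scalar_func` bound to `scipy.special.erf` in the function's globals
# and the whole thing njit-ed with a forced-scalar signature.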
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit(inline="always")
def switch(condition, x, y):
if condition:
return x
else:
return y
return switch
def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
    unique_names = unique_name_generator([binary_op_name], suffix_sep="_")
input_names = [unique_names(v, force_unique=True) for v in inputs]
input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names)
nary_src = f"""
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn = compile_function_src(nary_src, binary_op_name)
return nary_fn
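# A hedged sketch of the generated source (the input variables here are
# hypothetical): for three scalar inputs named x, y, and z, this compiles
# something like ``def add(x, y, z): return x + y + z``.
def _example_binary_to_nary():
    import aesara.tensor as at
    inputs = [at.scalar(name) for name in ("x", "y", "z")]
    add3 = binary_to_nary_func(inputs, "add", " + ")
    assert add3(1, 2, 3) == 6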
@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba.njit(signature, inline="always")(nary_add_fn)
@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba.njit(signature, inline="always")(nary_mul_fn)
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
if len(node.outputs) > 1:
raise NotImplementedError(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if use_signature:
signature = [create_numba_signature(node, force_scalar=True)]
else:
signature = []
numba_vectorize = numba.vectorize(signature, identity=identity)
elemwise_fn = numba_vectorize(scalar_op_fn)
elemwise_fn.py_scalar_func = scalar_op_fn
return elemwise_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn = create_vectorize_func(op, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern:
input_idx = op.inplace_pattern[0]
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys())
unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
input_names = [unique_names(i, force_unique=True) for i in input_names]
updated_input_name = input_names[input_idx]
inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
input_signature_str = ", ".join(input_names)
if node.inputs[input_idx].ndim > 0:
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
else:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
)
return numba.njit(inline="always")(inplace_elemwise_fn)
return elemwise_fn
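# The generated in-place wrapper for, e.g., an in-place `add` that updates its
# first input looks roughly like this sketch:
#
#     def add_inplace(x, y):
#         return add(x, y, x)
#
# i.e. the updated input is passed once more as the vectorized function's
# explicit output argument, so the ufunc writes the result back into it.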
def create_axis_reducer(
reduce_fn: numba.np.ufunc.dufunc.DUFunc,
identity: Union[np.ndarray, Number],
axis: int,
ndim: int,
dtype: numba.types.Type,
keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
r"""Create a Numba JITed function that performs a NumPy reduction on a given axis.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_axis(x):
            res_shape = tuple(x.shape[i] if i < axis else x.shape[i + 1] for i in range(ndim - 1))
res = np.full(res_shape, identity, dtype=dtype)
x_axis_first = x.transpose(reaxis_first)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
if keepdims:
return np.expand_dims(res, axis)
else:
return res
This can be removed/replaced when
https://github.com/numba/numba/issues/4504 is implemented.
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
identity:
The identity value for the reduction.
axis:
The axis to reduce.
ndim:
        The number of dimensions of the input.
dtype:
The data type of the result.
keepdims:
Determines whether or not the reduced dimension is retained.
"""
if ndim > 1:
if keepdims:
@numba.njit(inline="always")
def set_out_dims(x):
return np.expand_dims(x, axis)
else:
@numba.njit(inline="always")
def set_out_dims(x):
return x
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
@numba.njit(boundscheck=False)
def careduce_axis(x):
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res = np.full(res_shape, to_scalar(identity), dtype=dtype)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
return set_out_dims(res)
else:
if keepdims:
@numba.njit(inline="always")
def set_out_dims(x):
return np.array([x], dtype)
else:
@numba.njit(inline="always")
def set_out_dims(x):
return direct_cast(x, dtype)
@numba.njit(boundscheck=False)
def careduce_axis(x):
res = to_scalar(identity)
for val in x:
res = reduce_fn(res, val)
return set_out_dims(res)
return careduce_axis
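# A hedged sketch of `create_axis_reducer` (illustrative axis and dtype):
# summing over axis 0 of a 2-d float64 array matches NumPy's `sum(axis=0)`.
def _example_axis_reducer():
    reduce_sum0 = create_axis_reducer(np.add, 0.0, 0, 2, numba.float64)
    x = np.arange(6.0).reshape(2, 3)
    assert np.array_equal(reduce_sum0(x), x.sum(axis=0))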
def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_maximum(input):
axis_0_res = careduce_axes_fn_0(input)
axis_1_res = careduce_axes_fn_1(axis_0_res)
...
axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
return axis_N_res
The range 0-N is determined by the `axes` argument (i.e. the
axes to be reduced).
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
identity:
The identity value for the reduction.
axes:
The axes to reduce.
ndim:
        The number of dimensions of the input.
dtype:
The data type of the result.
"""
if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
global_env = {}
    # Reduce axes in ascending order: after each reduction the remaining
    # (higher) axes shift down by one, which is why `axis - i` is passed
    # to `create_axis_reducer` below.
    to_reduce = sorted(axes)
careduce_lines_src = []
var_name = input_name
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype
)
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return {var_name}
"""
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return numba.njit(careduce_fn)
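# A hedged sketch of `create_multiaxis_reducer` (illustrative): reducing both
# axes of a 2-d array chains two axis reducers and yields the full sum.
def _example_multiaxis_reducer():
    reduce_sum = create_multiaxis_reducer(np.add, 0.0, (0, 1), 2, numba.float64)
    x = np.arange(6.0).reshape(2, 3)
    assert reduce_sum(x) == x.sum()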
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
# TODO: Use `scalar_op_identity`?
elemwise_fn = create_vectorize_func(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
)
return careduce_fn
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
composite_fn = numba.njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array."""
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return rval.auto_name
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
else:
raise ValueError()
set_or_inc = isinstance(
node.op, (IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor)
)
index_start_idx = 1 + int(set_or_inc)
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
if len(indices_creation_src) == 1:
indices_creation_src = f"indices = ({indices_creation_src[0]},)"
else:
indices_creation_src = ", ".join(indices_creation_src)
indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
if node.inputs[1].ndim == 0:
# TODO FIXME: This is a hack to get around a weird Numba typing
# issue. See https://github.com/numba/numba/issues/6000
y_name = f"{input_names[1]}.item()"
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
index_body = f"z[indices] = {y_name}"
else:
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return z
"""
return subtensor_def_src
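# For a simple `Subtensor` like `x[a:b]`, the source returned above looks
# roughly like this sketch (the argument names are illustrative):
#
#     def subtensor(x, a, b):
#         indices = (slice(a, b, None),)
#         z = x[indices]
#         return z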
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
subtensor_def_src = create_index_func(
node, objmode=isinstance(op, AdvancedSubtensor)
)
global_env = {"np": np, "objmode": numba.objmode}
subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
return numba.njit(subtensor_fn)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src = create_index_func(
node, objmode=isinstance(op, AdvancedIncSubtensor)
)
global_env = {"np": np, "objmode": numba.objmode}
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", global_env
)
return numba.njit(incsubtensor_fn)
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
# Scalars are apparently returned as actual Python scalar types and not
    # NumPy scalars, so we need a separate Numba function for each case.
if node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba.njit(inline="always")
def deepcopyop(x):
return x
else:
@numba.njit(inline="always")
def deepcopyop(x):
return x.copy()
return deepcopyop
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba.njit
def makeslice(*x):
return slice(*x)
return makeslice
@numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit
def makevector(*args):
return np.array([a.item() for a in args], dtype=dtype)
return makevector
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba.njit(inline="always")
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba.njit(inline="always")
def shape_i(x):
return np.shape(x)[i]
return shape_i
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit(inline="always")
def tensor_from_scalar(x):
return np.array(x)
return tensor_from_scalar
@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit(inline="always")
def scalar_from_tensor(x):
return x.item()
return scalar_from_tensor
@numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar, "dtype": np.dtype(op.dtype)}
unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def allocempty({", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
return np.empty(scalar_shape, dtype)
"""
alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env)
return numba.njit(alloc_fn)
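# The source generated for a 2-d `AllocEmpty` looks roughly like this sketch
# (the shape argument names are illustrative):
#
#     def allocempty(s0, s1):
#         s0_item = to_scalar(s0)
#         s1_item = to_scalar(s1)
#         scalar_shape = (s0_item, s1_item)
#         return np.empty(scalar_shape, dtype)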
@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar}
unique_names = unique_name_generator(
["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env)
return numba.njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
offset = op.offset
@numba.njit(inline="always")
def allocdiag(v):
return np.diag(v, k=offset)
return allocdiag
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba.njit(inline="always")
def second(x, y):
return y
return second
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs):
shuffle = tuple(op.shuffle)
drop = tuple(op.drop)
augment = tuple(op.augment)
inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0:
@numba.njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
else:
        # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
        # typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba.njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)
@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]
new_shape = create_zeros_tuple()
j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
if not inplace:
return res_reshape.copy()
else:
return res_reshape
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
# E
# E There are 22 candidate implementations:
# E - Of which 22 did not match due to:
# E Overload of function 'getitem': File: <numerous>: Line N/A.
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
return dimshuffle
@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items())
@numba.njit
def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis):
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
return x
return rebroadcast
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
casted = typ.instance_type
elif isinstance(typ, numba.types.DTypeSpec):
casted = typ.dtype
else:
casted = typ
sig = casted(casted, typ)
def codegen(context, builder, signature, args):
val, _ = args
context.nrt.incref(builder, signature.return_type, val)
return val
return sig, codegen
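# A hedged sketch of `direct_cast` (illustrative): the declared signature makes
# Numba coerce the argument to the target dtype at the call boundary, so this
# behaves like a cast even though the codegen just passes the value through.
def _example_direct_cast():
    target_dtype = np.dtype("float32")
    @numba.njit
    def as_float32(x):
        return direct_cast(x, target_dtype)
    assert as_float32(np.float64(1.5)) == np.float32(1.5)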
@numba_funcify.register(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
@numba.njit(inline="always")
def cast(x):
return direct_cast(x, dtype)
return cast
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
if ndim == 0:
@numba.njit(inline="always")
def reshape(x, shape):
return x.item()
else:
@numba.njit(inline="always")
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
np.ascontiguousarray(np.asarray(x)), to_fixed_tuple(shape, ndim)
)
return reshape
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs):
@numba.njit
def specifyshape(x, shape):
assert np.array_equal(x.shape, shape)
return x
return specifyshape
@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
@numba.njit(inline="always")
def viewop(x):
return x
return viewop
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba.njit
def clip(_x, _min, _max):
x = to_scalar(_x)
_min_scalar = to_scalar(_min)
_max_scalar = to_scalar(_max)
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
return x
return clip
@numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
def arange(start, stop, step):
return np.arange(
to_scalar(start), to_scalar(stop), to_scalar(step), dtype=dtype
)
return arange
@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
view = op.view
if view != -1:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba.njit
def join(axis, *tensors):
return np.concatenate(tensors, to_scalar(axis))
return join
@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
return extract_diag
@numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
def eye(N, M, k):
return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype)
return eye
@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
@numba.njit(inline="always")
def bartlett(x):
return np.bartlett(to_scalar(x))
return bartlett
@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
axis = op.axis
mode = op.mode
ndim = node.outputs[0].ndim
    reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
    # `reaxis_first` is not its own inverse in general, so precompute the
    # inverse permutation to restore the original axis order afterward.
    reaxis_back = tuple(int(i) for i in np.argsort(reaxis_first))
if mode == "add":
np_func = np.add
identity = 0
else:
np_func = np.multiply
identity = 1
@numba.njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
return x.astype(out_dtype)
x_axis_first = x.transpose(reaxis_first)
res = np.empty(x_axis_first.shape, dtype=out_dtype)
for m in range(x.shape[axis]):
if m == 0:
np_func(identity, x_axis_first[m], res[m])
else:
np_func(res[m - 1], x_axis_first[m], res[m])
        return res.transpose(reaxis_back)
return cumop
@numba_funcify.register(DiffOp)
def numba_funcify_DiffOp(op, node, **kwargs):
n = op.n
axis = op.axis
ndim = node.inputs[0].ndim
dtype = node.outputs[0].dtype
axis = normalize_axis_index(axis, ndim)
slice1 = [slice(None)] * ndim
slice2 = [slice(None)] * ndim
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
op = np.not_equal if dtype == "bool" else np.subtract
@numba.njit(boundscheck=False)
def diffop(x):
res = x.copy()
for _ in range(n):
res = op(res[slice1], res[slice2])
return res
return diffop
@numba_funcify.register(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
@numba.njit
def filldiagonal(a, val):
np.fill_diagonal(a, val)
return a
return filldiagonal
@numba_funcify.register(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba.njit
def filldiagonaloffset(a, val, offset):
height, width = a.shape
if offset >= 0:
start = to_scalar(offset)
num_of_step = min(min(width, height), width - offset)
else:
start = -to_scalar(offset) * a.shape[1]
num_of_step = min(min(width, height), height + offset)
step = a.shape[1] + 1
end = start + step * num_of_step
b = a.ravel()
b[start:end:step] = val
# TODO: This isn't implemented in Numba
# a.flat[start:end:step] = val
# return a
return b.reshape(a.shape)
return filldiagonaloffset
@numba_funcify.register(RavelMultiIndex)
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
mode = op.mode
order = op.order
if order != "C":
raise NotImplementedError(
"Numba does not implement `order` in `numpy.ravel_multi_index`"
)
if mode == "raise":
@numba.njit
def mode_fn(*args):
raise ValueError("invalid entry in coordinates array")
elif mode == "wrap":
@numba.njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = v % d
elif mode == "clip":
@numba.njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = min(max(v, 0), d - 1)
if node.inputs[0].ndim == 0:
@numba.njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
if b < 0 or b >= shape[i]:
mode_fn(new_arr, i, 0, b, shape[i])
a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return np.array(a.dot(new_arr.T), dtype=np.int64)
else:
@numba.njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)
a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return a.dot(new_arr.T).astype(np.int64)
return ravelmultiindex
@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
axis = op.axis
use_python = False
if axis is not None:
use_python = True
if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def repeatop(x, repeats):
with numba.objmode(ret=ret_sig):
ret = np.repeat(x, repeats, axis)
return ret
else:
repeats_ndim = node.inputs[1].ndim
if repeats_ndim == 0:
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())
else:
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats)
return repeatop
@numba_funcify.register(Unique)
def numba_funcify_Unique(op, node, **kwargs):
axis = op.axis
use_python = False
if axis is not None:
use_python = True
return_index = op.return_index
return_inverse = op.return_inverse
return_counts = op.return_counts
returns_multi = return_index or return_inverse or return_counts
use_python |= returns_multi
if not use_python:
@numba.njit(inline="always")
def unique(x):
return np.unique(x)
else:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` and/or `return_*` arguments to `numpy.unique`."
),
UserWarning,
)
if returns_multi:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def unique(x):
with numba.objmode(ret=ret_sig):
ret = np.unique(x, return_index, return_inverse, return_counts, axis)
return ret
return unique
@numba_funcify.register(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
order = op.order
if order != "C":
raise NotImplementedError(
"Numba does not support the `order` argument in `numpy.unravel_index`"
)
if len(node.outputs) == 1:
@numba.njit(inline="always")
def maybe_expand_dim(arr):
return arr
else:
@numba.njit(inline="always")
def maybe_expand_dim(arr):
return np.expand_dims(arr, 1)
@numba.njit
def unravelindex(arr, shape):
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
# Aesara actually returns a `tuple` of these values, instead of an
# `ndarray`; however, this `ndarray` result should be able to be
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return ((maybe_expand_dim(arr) // a) % shape).T
return unravelindex
@numba_funcify.register(SearchsortedOp)
def numba_funcify_Searchsorted(op, node, **kwargs):
side = op.side
use_python = False
if len(node.inputs) == 3:
use_python = True
if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`sorter` argument to `numpy.searchsorted`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def searchsorted(a, v, sorter):
with numba.objmode(ret=ret_sig):
ret = np.searchsorted(a, v, side, sorter)
return ret
else:
@numba.njit(inline="always")
def searchsorted(a, v):
return np.searchsorted(a, v, side)
return searchsorted
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if any(i.type.numpy_dtype.kind in "ib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba.njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
@numba.njit(inline="always")
def inputs_cast(x):
return x
return inputs_cast
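# A hedged sketch of `int_to_float_fn` (the input variable is hypothetical):
# integer inputs are promoted to a float dtype of matching itemsize so that
# Numba's `np.dot` and the LAPACK-backed routines below can handle them.
def _example_int_to_float_fn():
    import aesara.tensor as at
    x = at.lmatrix("x")  # an int64 matrix input
    inputs_cast = int_to_float_fn([x], np.dtype("float64"))
    assert inputs_cast(np.ones((2, 2), dtype=np.int64)).dtype == np.float64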
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
# float.
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
return dot
@numba_funcify.register(Softmax)
def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
@numba.njit
def softmax(x):
z = np.expand_dims(reduce_max(x), -1)
e_x = np.exp(x - z)
w = np.expand_dims(reduce_sum(e_x), -1)
sm = e_x / w
return sm
return softmax
@numba_funcify.register(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1, keepdims=True)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True)
@numba.njit
def log_softmax(x):
xdev = x - np.expand_dims(reduce_max(x), -1)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm
return log_softmax
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)
@numba.njit
def softplus(x):
if x < -37.0:
return direct_cast(np.exp(x), x_dtype)
elif x < 18.0:
return direct_cast(np.log1p(np.exp(x)), x_dtype)
elif x < 33.3:
return direct_cast(x + np.exp(-x), x_dtype)
else:
return direct_cast(x, x_dtype)
return softplus
def create_axis_apply_fn(fn, axis, ndim, dtype):
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
@numba.njit(boundscheck=False)
def axis_apply_fn(x):
x_reaxis = x.transpose(reaxis_first)
res = np.zeros(x_reaxis.shape[:-1], dtype=dtype)
for m in np.ndindex(res.shape):
v = fn(x_reaxis[m])
res[m] = v
return res
return axis_apply_fn
@numba_funcify.register(MaxAndArgmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
if x_ndim == 0:
@numba.njit(inline="always")
def maxandargmax(x):
return x, 0
else:
axes = tuple(int(ax) for ax in axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer(
np.maximum, -np.inf, axes, x_ndim, x_dtype
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
)
reaxis_order = keep_axes + axes
sl1 = slice(None, len(keep_axes))
sl2 = slice(len(keep_axes), None)
@numba.njit
def maxandargmax(x):
max_res = reduce_max(x)
# Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1]
reduced_shape = transposed_x.shape[sl2]
reduced_size = 1
for s in reduced_shape:
reduced_size *= s
            # `np.prod` would return 1.0 for an empty `reduced_shape`, and
            # `reshape` would then complain about a float argument, so the
            # reduced size is accumulated with an integer loop instead.
new_shape = kept_shape + (reduced_size,)
reshaped_x = transposed_x.reshape(new_shape)
max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res
return maxandargmax
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
out_dtype = node.outputs[0].type.numpy_dtype
if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret
return cholesky
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
# check_finite = op.check_finite
if assume_a != "gen":
lower = op.lower
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`assume_a` argument to `scipy.linalg.solve`."
            ),
            UserWarning,
        )
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def solve(a, b):
return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve
@numba_funcify.register(Det)
def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def det(x):
return direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
return det
@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
@numba.njit
def eig(x):
out = np.linalg.eig(inputs_cast(x))
return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
return eig
@numba_funcify.register(Eigh)
def numba_funcify_Eigh(op, node, **kwargs):
uplo = op.UPLO
if uplo != "L":
warnings.warn(
(
"Numba will use object mode to allow the "
"`UPLO` argument to `numpy.linalg.eigh`."
),
UserWarning,
)
out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
ret_sig = numba.types.Tuple(
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
)
@numba.njit
def eigh(x):
with numba.objmode(ret=ret_sig):
out = np.linalg.eigh(x, UPLO=uplo)
ret = (out[0].astype(out_dtypes[0]), out[1].astype(out_dtypes[1]))
return ret
else:
@numba.njit(inline="always")
def eigh(x):
return np.linalg.eigh(x)
return eigh
@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return matrix_inverse
@numba_funcify.register(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
return matrixpinv
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return inv
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
mode = op.mode
if mode != "reduced":
warnings.warn(
(
"Numba will use object mode to allow the "
"`mode` argument to `numpy.linalg.qr`."
),
UserWarning,
)
if len(node.outputs) > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def qr_full(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.qr(x, mode=mode)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def qr_full(x):
return np.linalg.qr(inputs_cast(x))
return qr_full
@numba_funcify.register(SVD)
def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
if not compute_uv:
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def svd(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.svd(x, full_matrices, compute_uv)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
return svd
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba.njit
def batched_dot(x, y):
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
return z0
return batched_dot
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size = get_vector_length(node.inputs[1])
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
out_dtype = node.outputs[1].type.numpy_dtype
random_fn_global_env = {
bcast_fn_name: bcast_fn,
"out_dtype": out_dtype,
}
if tuple_size > 0:
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": to_fixed_tuple,
"tuple_size": tuple_size,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
return random_fn
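# For, e.g., a `normal(loc, scale)` node with a length-2 `size`, the generated
# wrapper looks roughly like this sketch (`out_dtype` and `size_dims` are
# closed-over constants):
#
#     def sized_random_variable(rng, size, dtype, loc, scale):
#         size = to_fixed_tuple(size, 2)
#         data = np.empty(size, dtype=out_dtype)
#         for i in np.ndindex(size[:size_dims]):
#             data[i] = aesara_random_normal(loc, scale)
#         return (rng, data)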
@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
def create_numba_random_fn(
op: Op,
node: Apply,
scalar_fn: Callable[[str], str],
global_env: Optional[Dict[str, Any]] = None,
) -> Callable:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function
construction cases that need unique-ified symbol names.
"""
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
if global_env:
np_global_env = global_env.copy()
else:
np_global_env = {}
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba.vectorize
unique_names = unique_name_generator(
[
np_random_fn_name,
]
+ list(np_global_env.keys())
+ [
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def body_fn(a, b):
return f" return {a} + {b} * abs(np.random.normal(0, 1))"
return create_numba_random_fn(op, node, body_fn)
@numba_funcify.register(aer.BernoulliRV)
def numba_funcify_BernoulliRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
def body_fn(a):
return f"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
"""
return create_numba_random_fn(
op, node, body_fn, {"out_dtype": out_dtype, "direct_cast": direct_cast}
)
# isort: off
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify
# Load dispatch specializations
import aesara.link.numba.dispatch.scalar
import aesara.link.numba.dispatch.tensor_basic
import aesara.link.numba.dispatch.extra_ops
import aesara.link.numba.dispatch.nlinalg
import aesara.link.numba.dispatch.random
import aesara.link.numba.dispatch.elemwise
# isort: on
import operator
import warnings
from functools import singledispatch
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba.core.errors import TypingError
from numba.extending import box
from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.link.utils import (
compile_function_src,
fgraph_to_python,
unique_name_generator,
)
from aesara.scalar.basic import Scalar
from aesara.scalar.math import Softplus
from aesara.tensor.blas import BatchedDot
from aesara.tensor.math import Dot
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type object for a ``Type``."""
if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar:
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar):
dtype = np.dtype(aesara_type.dtype)
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
else:
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types = []
for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar))
output_types = []
for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar))
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
elif len(output_types) == 1:
return output_types[0](*input_types)
else:
return numba.types.void(*input_types)
def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
return self.builder.call(fn, [start, stop, step])
def enable_slice_boxing():
"""Enable boxing for Numba's native ``slice``s.
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
merged and a release is made.
"""
@box(types.SliceType)
def box_slice(typ, val, c):
"""Implement boxing for ``slice`` objects in Numba.
        This makes it possible to return Numba's internal representation of a
``slice`` object as a proper ``slice`` to Python.
"""
start = c.box(types.int64, c.builder.extract_value(val, 0))
stop = c.box(types.int64, c.builder.extract_value(val, 1))
if typ.has_step:
step = c.box(types.int64, c.builder.extract_value(val, 2))
else:
step = c.pyapi.get_null_object()
slice_val = slice_new(c.pyapi, start, stop, step)
return slice_val
@numba.extending.overload(operator.contains)
def in_seq_empty_tuple(x, y):
if isinstance(x, types.Tuple) and not x.types:
return lambda x, y: False
enable_slice_boxing()
@numba.generated_jit(nopython=True)
def to_scalar(x):
if isinstance(x, (numba.types.Number, numba.types.Boolean)):
return lambda x: x
elif isinstance(x, numba.types.Array):
return lambda x: x.item()
else:
raise TypingError(f"{x} must be a scalar compatible type.")
def enable_slice_literals():
"""Enable lowering for ``SliceLiteral``s.
TODO: This can be removed once https://github.com/numba/numba/pull/6996 is merged
and a release is made.
"""
from numba.core import types
from numba.core.datamodel.models import SliceModel
from numba.core.datamodel.registry import register_default
from numba.core.imputils import lower_cast, lower_constant
from numba.core.types.misc import SliceLiteral
from numba.cpython.slicing import get_defaults
register_default(numba.types.misc.SliceLiteral)(SliceModel)
@property
def key(self):
return self.name
SliceLiteral.key = key
def make_slice_from_constant(context, builder, ty, pyval):
sli = context.make_helper(builder, ty)
lty = context.get_value_type(types.intp)
(
default_start_pos,
default_start_neg,
default_stop_pos,
default_stop_neg,
default_step,
) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
step = pyval.step
if step is None:
step_is_neg = False
step = default_step
else:
step_is_neg = step < 0
step = lty(step)
start = pyval.start
if start is None:
if step_is_neg:
start = default_start_neg
else:
start = default_start_pos
else:
start = lty(start)
stop = pyval.stop
if stop is None:
if step_is_neg:
stop = default_stop_neg
else:
stop = default_stop_pos
else:
stop = lty(stop)
sli.start = start
sli.stop = stop
sli.step = step
return sli._getvalue()
@lower_constant(numba.types.SliceType)
def constant_slice(context, builder, ty, pyval):
if isinstance(ty, types.Literal):
typ = ty.literal_type
else:
typ = ty
return make_slice_from_constant(context, builder, typ, pyval)
@lower_cast(numba.types.misc.SliceLiteral, numba.types.SliceType)
def cast_from_literal(context, builder, fromty, toty, val):
return make_slice_from_constant(
context,
builder,
toty,
fromty.literal_value,
)
enable_slice_literals()
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
See https://github.com/numba/numba/issues/2771#issuecomment-414358902
"""
assert n > 0
f = numba.njit(f)
@numba.njit
def creator(args):
return (f(0, *args),)
for i in range(1, n):
@numba.njit
def creator(args, creator=creator, i=i):
return creator(args) + (f(i, *args),)
return numba.njit(lambda *args: creator(args))
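# Illustrative use (hypothetical, not from this module): build a jitted
# function that reads an array's shape into a fixed-length tuple.
#
#     shape3 = create_tuple_creator(lambda i, x: x.shape[i], 3)
#     shape3(np.zeros((2, 3, 4)))  # -> (2, 3, 4), usable in nopython code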
def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
UserWarning,
)
n_outputs = len(node.outputs)
if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
outputs = tuple([o[0] for o in outputs])
if n_outputs == 1:
ret = outputs[0]
else:
ret = outputs
return ret
return perform
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="numba_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
)
def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array."""
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return rval.auto_name
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
else:
raise ValueError(f"Unsupported index entry: {entry}")
set_or_inc = isinstance(
node.op, (IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor)
)
index_start_idx = 1 + int(set_or_inc)
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
if len(indices_creation_src) == 1:
indices_creation_src = f"indices = ({indices_creation_src[0]},)"
else:
indices_creation_src = ", ".join(indices_creation_src)
indices_creation_src = f"indices = ({indices_creation_src})"
if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
if node.inputs[1].ndim == 0:
# TODO FIXME: This is a hack to get around a weird Numba typing
# issue. See https://github.com/numba/numba/issues/6000
y_name = f"{input_names[1]}.item()"
else:
y_name = input_names[1]
if node.op.set_instead_of_inc:
index_body = f"z[indices] = {y_name}"
else:
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"
if objmode:
output_var = node.outputs[0]
if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""
index_body = f"""
with objmode({output_sig}):
{index_body}
"""
subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return z
"""
return subtensor_def_src
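# Sketch of the generated source for a basic `x[a:b]` graph (argument names
# come from `unique_name_generator` and the index names from `auto_name`, so
# they will differ in practice):
#
#     def subtensor(x, auto_0, auto_1):
#         indices = (slice(auto_0, auto_1, None),)
#         z = x[indices]
#         return z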
@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
subtensor_def_src = create_index_func(
node, objmode=isinstance(op, AdvancedSubtensor)
)
global_env = {"np": np, "objmode": numba.objmode}
subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
return numba.njit(subtensor_fn)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src = create_index_func(
node, objmode=isinstance(op, AdvancedIncSubtensor)
)
global_env = {"np": np, "objmode": numba.objmode}
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", global_env
)
return numba.njit(incsubtensor_fn)
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
# Scalars are apparently returned as actual Python scalar types and not
# NumPy scalars, so we need a separate Numba function for each case.
if node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba.njit(inline="always")
def deepcopyop(x):
return x
else:
@numba.njit(inline="always")
def deepcopyop(x):
return x.copy()
return deepcopyop
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba.njit
def makeslice(*x):
return slice(*x)
return makeslice
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba.njit(inline="always")
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba.njit(inline="always")
def shape_i(x):
return np.shape(x)[i]
return shape_i
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
casted = typ.instance_type
elif isinstance(typ, numba.types.DTypeSpec):
casted = typ.dtype
else:
casted = typ
sig = casted(casted, typ)
def codegen(context, builder, signature, args):
val, _ = args
context.nrt.incref(builder, signature.return_type, val)
return val
return sig, codegen
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
if ndim == 0:
@numba.njit(inline="always")
def reshape(x, shape):
return x.item()
else:
@numba.njit(inline="always")
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
np.ascontiguousarray(np.asarray(x)),
to_fixed_tuple(shape, ndim),
)
return reshape
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs):
@numba.njit
def specifyshape(x, shape):
assert np.array_equal(x.shape, shape)
return x
return specifyshape
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if any(i.type.numpy_dtype.kind in "ib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba.njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
@numba.njit(inline="always")
def inputs_cast(x):
return x
return inputs_cast
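# e.g. (sketch): for int64 or bool inputs and a float64 output dtype, the
# returned function casts with `x.astype(np.dtype("f8"))` (a float of the same
# itemsize); for float inputs it is an identity function that Numba can inline.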
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
# float.
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
return dot
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)
@numba.njit
def softplus(x):
if x < -37.0:
return direct_cast(np.exp(x), x_dtype)
elif x < 18.0:
return direct_cast(np.log1p(np.exp(x)), x_dtype)
elif x < 33.3:
return direct_cast(x + np.exp(-x), x_dtype)
else:
return direct_cast(x, x_dtype)
return softplus
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
out_dtype = node.outputs[0].type.numpy_dtype
if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret
return cholesky
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
# check_finite = op.check_finite
if assume_a != "gen":
lower = op.lower
warnings.warn(
(
"Numba will use object mode to run "
"`scipy.linalg.solve_triangular` when `assume_a != 'gen'`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def solve(a, b):
return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba.njit
def batched_dot(x, y):
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
return z0
return batched_dot
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
import inspect
from numbers import Number
from textwrap import indent
from typing import Union
import numba
import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara.graph.basic import Apply
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
create_numba_signature,
create_tuple_creator,
numba_funcify,
)
from aesara.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.type import tensor
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
if len(node.outputs) > 1:
raise NotImplementedError(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if use_signature:
signature = [create_numba_signature(node, force_scalar=True)]
else:
signature = []
numba_vectorize = numba.vectorize(signature, identity=identity)
elemwise_fn = numba_vectorize(scalar_op_fn)
elemwise_fn.py_scalar_func = scalar_op_fn
return elemwise_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn = create_vectorize_func(op, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern:
input_idx = op.inplace_pattern[0]
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys())
unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
input_names = [unique_names(i, force_unique=True) for i in input_names]
updated_input_name = input_names[input_idx]
inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
input_signature_str = ", ".join(input_names)
if node.inputs[input_idx].ndim > 0:
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
else:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
)
return numba.njit(inline="always")(inplace_elemwise_fn)
return elemwise_fn
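# Sketch of the in-place wrapper generated above for, say, an in-place `add`
# that updates its first input (names are illustrative):
#
#     def add_inplace(x, y):
#         return add(x, y, x)
#
# i.e. the updated input is passed again as the explicit output argument that
# the `numba.vectorize`d function accepts.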
def create_axis_reducer(
reduce_fn: numba.np.ufunc.dufunc.DUFunc,
identity: Union[np.ndarray, Number],
axis: int,
ndim: int,
dtype: numba.types.Type,
keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
r"""Create a Numba JITed function that performs a NumPy reduction on a given axis.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_axis(x):
res_shape = tuple(x.shape[i] if i < axis else x.shape[i + 1] for i in range(ndim - 1))
res = np.full(res_shape, identity, dtype=dtype)
x_axis_first = x.transpose(reaxis_first)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
if keepdims:
return np.expand_dims(res, axis)
else:
return res
This can be removed/replaced when
https://github.com/numba/numba/issues/4504 is implemented.
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
identity:
The identity value for the reduction.
axis:
The axis to reduce.
ndim:
The number of dimensions of the input.
dtype:
The data type of the result.
keepdims:
Determines whether or not the reduced dimension is retained.
"""
if ndim > 1:
if keepdims:
@numba.njit(inline="always")
def set_out_dims(x):
return np.expand_dims(x, axis)
else:
@numba.njit(inline="always")
def set_out_dims(x):
return x
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
@numba.njit(boundscheck=False)
def careduce_axis(x):
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res = np.full(res_shape, numba_basic.to_scalar(identity), dtype=dtype)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
return set_out_dims(res)
else:
if keepdims:
@numba.njit(inline="always")
def set_out_dims(x):
return np.array([x], dtype)
else:
@numba.njit(inline="always")
def set_out_dims(x):
return numba_basic.direct_cast(x, dtype)
@numba.njit(boundscheck=False)
def careduce_axis(x):
res = numba_basic.to_scalar(identity)
for val in x:
res = reduce_fn(res, val)
return set_out_dims(res)
return careduce_axis
def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_maximum(input):
axis_0_res = careduce_axes_fn_0(input)
axis_1_res = careduce_axes_fn_1(axis_0_res)
...
axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
return axis_N_res
The range 0-N is determined by the `axes` argument (i.e. the
axes to be reduced).
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
identity:
The identity value for the reduction.
axes:
The axes to reduce.
ndim:
The number of dimensions of the input.
dtype:
The data type of the result.
"""
if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
global_env = {}
to_reduce = reversed(sorted(axes))
careduce_lines_src = []
var_name = input_name
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype
)
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return {var_name}
"""
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return numba.njit(careduce_fn)
def create_axis_apply_fn(fn, axis, ndim, dtype):
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
@numba.njit(boundscheck=False)
def axis_apply_fn(x):
x_reaxis = x.transpose(reaxis_first)
res = np.zeros(x_reaxis.shape[:-1], dtype=dtype)
for m in np.ndindex(res.shape):
v = fn(x_reaxis[m])
res[m] = v
return res
return axis_apply_fn
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
# TODO: Use `scalar_op_identity`?
elemwise_fn = create_vectorize_func(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
)
return careduce_fn
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs):
shuffle = tuple(op.shuffle)
drop = tuple(op.drop)
augment = tuple(op.augment)
inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0:
@numba.njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
else:
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
# typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba.njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)
@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]
new_shape = create_zeros_tuple()
j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
if not inplace:
return res_reshape.copy()
else:
return res_reshape
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
# E
# E There are 22 candidate implementations:
# E - Of which 22 did not match due to:
# E Overload of function 'getitem': File: <numerous>: Line N/A.
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
return dimshuffle
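# For example (sketch): with `shuffle == (1, 0)` and `augment == (2,)` (i.e. a
# DimShuffle order of `(1, 0, "x")`), an input of shape (2, 3) is transposed
# to shape (3, 2) and then reshaped to (3, 2, 1).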
@numba_funcify.register(Softmax)
def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
@numba.njit
def softmax(x):
z = np.expand_dims(reduce_max(x), -1)
e_x = np.exp(x - z)
w = np.expand_dims(reduce_sum(e_x), -1)
sm = e_x / w
return sm
return softmax
@numba_funcify.register(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1, keepdims=True)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True)
@numba.njit
def log_softmax(x):
xdev = x - np.expand_dims(reduce_max(x), -1)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm
return log_softmax
@numba_funcify.register(MaxAndArgmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
if x_ndim == 0:
@numba.njit(inline="always")
def maxandargmax(x):
return x, 0
else:
axes = tuple(int(ax) for ax in axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer(
np.maximum, -np.inf, axes, x_ndim, x_dtype
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
)
reaxis_order = keep_axes + axes
sl1 = slice(None, len(keep_axes))
sl2 = slice(len(keep_axes), None)
@numba.njit
def maxandargmax(x):
max_res = reduce_max(x)
# Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1]
reduced_shape = transposed_x.shape[sl2]
reduced_size = 1
for s in reduced_shape:
reduced_size *= s
# `np.prod` would return a float (1.0) for an empty argument, and `reshape`
# rejects float shapes, so we compute the product with an integer loop instead.
new_shape = kept_shape + (reduced_size,)
reshaped_x = transposed_x.reshape(new_shape)
max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res
return maxandargmax
import warnings
import numba
import numpy as np
from numpy.core.multiarray import normalize_axis_index
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import get_numba_type, numba_funcify
from aesara.tensor.extra_ops import (
Bartlett,
CumOp,
DiffOp,
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
@numba.njit(inline="always")
def bartlett(x):
return np.bartlett(numba_basic.to_scalar(x))
return bartlett
@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
axis = op.axis
mode = op.mode
ndim = node.outputs[0].ndim
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
# The transpose below is undone with the inverse permutation; for `axis >= 2`
# the permutation is not its own inverse.
reaxis_first_inv = tuple(int(i) for i in np.argsort(reaxis_first))
if mode == "add":
np_func = np.add
identity = 0
else:
np_func = np.multiply
identity = 1
@numba.njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
return x.astype(out_dtype)
x_axis_first = x.transpose(reaxis_first)
res = np.empty(x_axis_first.shape, dtype=out_dtype)
for m in range(x.shape[axis]):
if m == 0:
np_func(identity, x_axis_first[m], res[m])
else:
np_func(res[m - 1], x_axis_first[m], res[m])
return res.transpose(reaxis_first_inv)
return cumop
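# e.g. (sketch): for `mode == "add"` this reproduces `np.cumsum(x, axis=axis)`
# by accumulating `res[m] = res[m - 1] + x_axis_first[m]` one slice at a time
# along the (temporarily front-most) reduction axis.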
@numba_funcify.register(DiffOp)
def numba_funcify_DiffOp(op, node, **kwargs):
n = op.n
axis = op.axis
ndim = node.inputs[0].ndim
dtype = node.outputs[0].dtype
axis = normalize_axis_index(axis, ndim)
slice1 = [slice(None)] * ndim
slice2 = [slice(None)] * ndim
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
op = np.not_equal if dtype == "bool" else np.subtract
@numba.njit(boundscheck=False)
def diffop(x):
res = x.copy()
for _ in range(n):
res = op(res[slice1], res[slice2])
return res
return diffop
@numba_funcify.register(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
@numba.njit
def filldiagonal(a, val):
np.fill_diagonal(a, val)
return a
return filldiagonal
@numba_funcify.register(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba.njit
def filldiagonaloffset(a, val, offset):
height, width = a.shape
if offset >= 0:
start = numba_basic.to_scalar(offset)
num_of_step = min(min(width, height), width - offset)
else:
start = -numba_basic.to_scalar(offset) * a.shape[1]
num_of_step = min(min(width, height), height + offset)
step = a.shape[1] + 1
end = start + step * num_of_step
b = a.ravel()
b[start:end:step] = val
# TODO: This isn't implemented in Numba
# a.flat[start:end:step] = val
# return a
return b.reshape(a.shape)
return filldiagonaloffset
@numba_funcify.register(RavelMultiIndex)
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
mode = op.mode
order = op.order
if order != "C":
raise NotImplementedError(
"Numba does not implement `order` in `numpy.ravel_multi_index`"
)
if mode == "raise":
@numba.njit
def mode_fn(*args):
raise ValueError("invalid entry in coordinates array")
elif mode == "wrap":
@numba.njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = v % d
elif mode == "clip":
@numba.njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = min(max(v, 0), d - 1)
if node.inputs[0].ndim == 0:
@numba.njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
if b < 0 or b >= shape[i]:
mode_fn(new_arr, i, 0, b, shape[i])
a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return np.array(a.dot(new_arr.T), dtype=np.int64)
else:
@numba.njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)
a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return a.dot(new_arr.T).astype(np.int64)
return ravelmultiindex
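# A quick sanity sketch: for `shape == (3, 4)` and the multi-index (2, 1), the
# weight vector `a` comes out as (4, 1), so the result is 2 * 4 + 1 == 9,
# matching `np.ravel_multi_index((2, 1), (3, 4))`.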
@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
axis = op.axis
use_python = False
if axis is not None:
use_python = True
if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def repeatop(x, repeats):
with numba.objmode(ret=ret_sig):
ret = np.repeat(x, repeats, axis)
return ret
else:
repeats_ndim = node.inputs[1].ndim
if repeats_ndim == 0:
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())
else:
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats)
return repeatop
@numba_funcify.register(Unique)
def numba_funcify_Unique(op, node, **kwargs):
axis = op.axis
use_python = False
if axis is not None:
use_python = True
return_index = op.return_index
return_inverse = op.return_inverse
return_counts = op.return_counts
returns_multi = return_index or return_inverse or return_counts
use_python |= returns_multi
if not use_python:
@numba.njit(inline="always")
def unique(x):
return np.unique(x)
else:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` and/or `return_*` arguments to `numpy.unique`."
),
UserWarning,
)
if returns_multi:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def unique(x):
with numba.objmode(ret=ret_sig):
ret = np.unique(x, return_index, return_inverse, return_counts, axis)
return ret
return unique
@numba_funcify.register(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
order = op.order
if order != "C":
raise NotImplementedError(
"Numba does not support the `order` argument in `numpy.unravel_index`"
)
if len(node.outputs) == 1:
@numba.njit(inline="always")
def maybe_expand_dim(arr):
return arr
else:
@numba.njit(inline="always")
def maybe_expand_dim(arr):
return np.expand_dims(arr, 1)
@numba.njit
def unravelindex(arr, shape):
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
# Aesara actually returns a `tuple` of these values, instead of an
# `ndarray`; however, this `ndarray` result should be able to be
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return ((maybe_expand_dim(arr) // a) % shape).T
return unravelindex
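# Sanity sketch: for `shape == (3, 4)`, the stride vector `a` comes out as
# (4, 1), so a flat index of 9 maps to ((9 // 4) % 3, (9 // 1) % 4) == (2, 1),
# matching `np.unravel_index(9, (3, 4))`.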
@numba_funcify.register(SearchsortedOp)
def numba_funcify_Searchsorted(op, node, **kwargs):
side = op.side
use_python = False
if len(node.inputs) == 3:
use_python = True
if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`sorter` argument to `numpy.searchsorted`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def searchsorted(a, v, sorter):
with numba.objmode(ret=ret_sig):
ret = np.searchsorted(a, v, side, sorter)
return ret
else:
@numba.njit(inline="always")
def searchsorted(a, v):
return np.searchsorted(a, v, side)
return searchsorted
import warnings
import numba
import numpy as np
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
get_numba_type,
int_to_float_fn,
numba_funcify,
)
from aesara.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
)
@numba_funcify.register(SVD)
def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
if not compute_uv:
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def svd(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.svd(x, full_matrices, compute_uv)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
return svd
@numba_funcify.register(Det)
def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def det(x):
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
return det
@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
@numba.njit
def eig(x):
out = np.linalg.eig(inputs_cast(x))
return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
return eig
@numba_funcify.register(Eigh)
def numba_funcify_Eigh(op, node, **kwargs):
uplo = op.UPLO
if uplo != "L":
warnings.warn(
(
"Numba will use object mode to allow the "
"`UPLO` argument to `numpy.linalg.eigh`."
),
UserWarning,
)
out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
ret_sig = numba.types.Tuple(
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
)
@numba.njit
def eigh(x):
with numba.objmode(ret=ret_sig):
out = np.linalg.eigh(x, UPLO=uplo)
ret = (out[0].astype(out_dtypes[0]), out[1].astype(out_dtypes[1]))
return ret
else:
@numba.njit(inline="always")
def eigh(x):
return np.linalg.eigh(x)
return eigh
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return inv
@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return matrix_inverse
@numba_funcify.register(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
return matrixpinv
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
mode = op.mode
if mode != "reduced":
warnings.warn(
(
"Numba will use object mode to allow the "
"`mode` argument to `numpy.linalg.qr`."
),
UserWarning,
)
if len(node.outputs) > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def qr_full(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.qr(x, mode=mode)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def qr_full(x):
return np.linalg.qr(inputs_cast(x))
return qr_full
from textwrap import dedent, indent
from typing import Any, Callable, Dict, Optional
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify
from aesara.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size = get_vector_length(node.inputs[1])
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
out_dtype = node.outputs[1].type.numpy_dtype
random_fn_global_env = {
bcast_fn_name: bcast_fn,
"out_dtype": out_dtype,
}
if tuple_size > 0:
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": numba_ndarray.to_fixed_tuple,
"tuple_size": tuple_size,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
return random_fn
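# Sketch of the generated pair for e.g. a `normal(loc, scale, size=(2, 3))`
# node (names are illustrative; the real ones come from
# `unique_name_generator`):
#
#     @numba_vectorize
#     def aesara_random_normal(loc, scale):
#         return np_random_func(loc, scale)
#
#     def sized_random_variable(rng, size, dtype, loc, scale):
#         size = to_fixed_tuple(size, tuple_size)
#         data = np.empty(size, dtype=out_dtype)
#         for i in np.ndindex(size[:size_dims]):
#             data[i] = aesara_random_normal(loc, scale)
#         return (rng, data)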
@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
def create_numba_random_fn(
op: Op,
node: Apply,
scalar_fn: Callable[[str], str],
global_env: Optional[Dict[str, Any]] = None,
) -> Callable:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function
construction cases that need unique-ified symbol names.
"""
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
if global_env:
np_global_env = global_env.copy()
else:
np_global_env = {}
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba.vectorize
unique_names = unique_name_generator(
[
np_random_fn_name,
]
+ list(np_global_env.keys())
+ [
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def body_fn(a, b):
return f" return {a} + {b} * abs(np.random.normal(0, 1))"
return create_numba_random_fn(op, node, body_fn)
@numba_funcify.register(aer.BernoulliRV)
def numba_funcify_BernoulliRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
def body_fn(a):
return f"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
"""
return create_numba_random_fn(
op,
node,
body_fn,
{"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast},
)
from functools import reduce
from typing import List
import numba
import numpy as np
import scipy
import scipy.special
from aesara.compile.ops import ViewOp
from aesara.graph.basic import Variable
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import create_numba_signature, numba_funcify
from aesara.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import (
Add,
Cast,
Clip,
Composite,
Identity,
Mul,
ScalarOp,
Second,
Switch,
)
@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0]
if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]
else:
func_package = np
if "." in scalar_func_name:
scalar_func = reduce(getattr, [func_package] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
scalar_op_fn_name = get_name_for_object(scalar_func)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
)
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
global_env = {"scalar_func": scalar_func}
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
scalar_op_fn = compile_function_src(scalar_op_src, scalar_op_fn_name, global_env)
signature = create_numba_signature(node, force_scalar=True)
return numba.njit(signature, inline="always")(scalar_op_fn)
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit(inline="always")
def switch(condition, x, y):
if condition:
return x
else:
return y
return switch
def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
input_names = [unique_names(v, force_unique=True) for v in inputs]
input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names)
nary_src = f"""
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn = compile_function_src(nary_src, binary_op_name)
return nary_fn
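# e.g. for three inputs and `binary_op == "+"`, the generated source is simply
#
#     def add(x, y, z):
#         return x + y + z
#
# (argument names are uniquified versions of the input variables' names).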
@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba.njit(signature, inline="always")(nary_add_fn)
@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba.njit(signature, inline="always")(nary_mul_fn)
@numba_funcify.register(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
@numba.njit(inline="always")
def cast(x):
return numba_basic.direct_cast(x, dtype)
return cast
@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
@numba.njit(inline="always")
def viewop(x):
return x
return viewop
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba.njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
return x
return clip
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
composite_fn = numba.njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba.njit(inline="always")
def second(x, y):
return y
return second
from textwrap import indent
import numba
import numpy as np
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from aesara.link.utils import compile_function_src, unique_name_generator
from aesara.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
@numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {
"np": np,
"to_scalar": numba_basic.to_scalar,
"dtype": np.dtype(op.dtype),
}
unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def allocempty({", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
return np.empty(scalar_shape, dtype)
"""
alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env)
return numba.njit(alloc_fn)
@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": numba_basic.to_scalar}
unique_names = unique_name_generator(
["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env)
return numba.njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
offset = op.offset
@numba.njit(inline="always")
def allocdiag(v):
return np.diag(v, k=offset)
return allocdiag
@numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
def arange(start, stop, step):
return np.arange(
numba_basic.to_scalar(start),
numba_basic.to_scalar(stop),
numba_basic.to_scalar(step),
dtype=dtype,
)
return arange
@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
view = op.view
if view != -1:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba.njit
def join(axis, *tensors):
return np.concatenate(tensors, numba_basic.to_scalar(axis))
return join
@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
return extract_diag
@numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
def eye(N, M, k):
return np.eye(
numba_basic.to_scalar(N),
numba_basic.to_scalar(M),
numba_basic.to_scalar(k),
dtype=dtype,
)
return eye
@numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit
def makevector(*args):
return np.array([a.item() for a in args], dtype=dtype)
return makevector
@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items())
@numba.njit
def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis):
if value and x.shape[axis] != 1:
raise ValueError(
"Dimension in Rebroadcast's input was supposed to be 1"
)
return x
return rebroadcast
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit(inline="always")
def tensor_from_scalar(x):
return np.array(x)
return tensor_from_scalar
@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit(inline="always")
def scalar_from_tensor(x):
return x.item()
return scalar_from_tensor
-from numpy.random import RandomState
 from aesara.link.basic import JITLinker

@@ -18,6 +16,8 @@ class NumbaLinker(JITLinker):
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
+        from numpy.random import RandomState
+        from aesara.link.numba.dispatch import numba_typify
         thunk_inputs = []
@@ -25,7 +25,7 @@ from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.optdb import OptimizationQuery
 from aesara.graph.type import Type
-from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
+from aesara.link.numba.dispatch import basic as numba_basic
 from aesara.link.numba.linker import NumbaLinker
 from aesara.scalar.basic import Composite
 from aesara.tensor import blas
@@ -147,20 +147,21 @@ def eval_python_only(fn_inputs, fgraph, inputs):
     else:
         return wrap

-    with mock.patch("aesara.link.numba.dispatch.numba.njit", njit_noop), mock.patch(
-        "aesara.link.numba.dispatch.numba.vectorize",
+    with mock.patch("numba.njit", njit_noop), mock.patch(
+        "numba.vectorize",
         vectorize_noop,
     ), mock.patch(
-        "aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem
+        "aesara.link.numba.dispatch.elemwise.tuple_setitem",
+        py_tuple_setitem,
     ), mock.patch(
-        "aesara.link.numba.dispatch.direct_cast", lambda x, dtype: x
+        "aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
     ), mock.patch(
-        "aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype",
+        "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
         lambda dtype: dtype,
     ), mock.patch(
-        "aesara.link.numba.dispatch.to_scalar", py_to_scalar
+        "aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar
     ), mock.patch(
-        "aesara.link.numba.dispatch.to_fixed_tuple",
+        "numba.np.unsafe.ndarray.to_fixed_tuple",
         lambda x, n: tuple(x),
     ):
         aesara_numba_fn = function(
@@ -247,7 +248,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
         else pytest.raises(NotImplementedError)
     )
     with cm:
-        res = get_numba_type(v, force_scalar=force_scalar)
+        res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
         assert res == expected
@@ -289,7 +290,7 @@ def test_get_numba_type(v, expected, force_scalar):
     ],
 )
 def test_create_numba_signature(v, expected, force_scalar):
-    res = create_numba_signature(v, force_scalar=force_scalar)
+    res = numba_basic.create_numba_signature(v, force_scalar=force_scalar)
     assert res == expected