Commit d4696e6b authored by Brandon T. Willard, committed by Brandon T. Willard

Split aesara.link.numba.dispatch into distinct modules

Parent 0ae63d1b
# aesara/link/numba/dispatch/__init__.py

# isort: off
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify

# Load dispatch specializations
import aesara.link.numba.dispatch.scalar
import aesara.link.numba.dispatch.tensor_basic
import aesara.link.numba.dispatch.extra_ops
import aesara.link.numba.dispatch.nlinalg
import aesara.link.numba.dispatch.random
import aesara.link.numba.dispatch.elemwise

# isort: on
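
These submodule imports are what populate the dispatcher: `numba_funcify` is a `functools.singledispatch` function defined in `dispatch.basic`, so importing a module whose body calls `numba_funcify.register(SomeOp)` is enough to add support for `SomeOp`. A minimal sketch of the same registration pattern (hypothetical `MyOp`, illustration only, not part of this commit):

from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No Numba implementation of {type(op)}")


class MyOp:  # stand-in for an Aesara `Op`
    pass


@funcify.register(MyOp)  # executed at import time, filling the registry
def funcify_MyOp(op, **kwargs):
    def myop_impl(x):
        return x

    return myop_impl


assert callable(funcify(MyOp()))  # dispatches on the `Op` subclass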

# aesara/link/numba/dispatch/extra_ops.py

import warnings

import numba
import numpy as np
from numpy.core.multiarray import normalize_axis_index

from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import get_numba_type, numba_funcify
from aesara.tensor.extra_ops import (
    Bartlett,
    CumOp,
    DiffOp,
    FillDiagonal,
    FillDiagonalOffset,
    RavelMultiIndex,
    Repeat,
    SearchsortedOp,
    Unique,
    UnravelIndex,
)


@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
    @numba.njit(inline="always")
    def bartlett(x):
        return np.bartlett(numba_basic.to_scalar(x))

    return bartlett


@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
    axis = op.axis
    mode = op.mode
    ndim = node.outputs[0].ndim

    reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
    # This permutation is not its own inverse in general, so use the actual
    # inverse when transposing the result back to the input's layout.
    reaxis_first_inv = tuple(np.argsort(reaxis_first))

    if mode == "add":
        np_func = np.add
        identity = 0
    else:
        np_func = np.multiply
        identity = 1

    @numba.njit(boundscheck=False)
    def cumop(x):
        out_dtype = x.dtype
        if x.shape[axis] < 2:
            return x.astype(out_dtype)

        x_axis_first = x.transpose(reaxis_first)
        res = np.empty(x_axis_first.shape, dtype=out_dtype)

        for m in range(x.shape[axis]):
            if m == 0:
                np_func(identity, x_axis_first[m], res[m])
            else:
                np_func(res[m - 1], x_axis_first[m], res[m])

        return res.transpose(reaxis_first_inv)

    return cumop
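
# Quick check of the kernel above (illustrative, not part of this commit):
# for `mode == "add"` it should agree with `np.cumsum` along `axis`, e.g.
#
#   x = np.arange(24, dtype="float64").reshape((2, 3, 4))
#   # with `axis = 1`: np.array_equal(cumop(x), np.cumsum(x, axis=1))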


@numba_funcify.register(DiffOp)
def numba_funcify_DiffOp(op, node, **kwargs):
    n = op.n
    axis = op.axis
    ndim = node.inputs[0].ndim
    dtype = node.outputs[0].dtype

    axis = normalize_axis_index(axis, ndim)

    slice1 = [slice(None)] * ndim
    slice2 = [slice(None)] * ndim
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)

    op = np.not_equal if dtype == "bool" else np.subtract

    @numba.njit(boundscheck=False)
    def diffop(x):
        res = x.copy()

        for _ in range(n):
            res = op(res[slice1], res[slice2])

        return res

    return diffop


@numba_funcify.register(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
    @numba.njit
    def filldiagonal(a, val):
        np.fill_diagonal(a, val)
        return a

    return filldiagonal


@numba_funcify.register(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
    @numba.njit
    def filldiagonaloffset(a, val, offset):
        height, width = a.shape

        if offset >= 0:
            start = numba_basic.to_scalar(offset)
            num_of_step = min(min(width, height), width - offset)
        else:
            start = -numba_basic.to_scalar(offset) * a.shape[1]
            num_of_step = min(min(width, height), height + offset)

        step = a.shape[1] + 1
        end = start + step * num_of_step
        b = a.ravel()
        b[start:end:step] = val
        # TODO: This isn't implemented in Numba
        # a.flat[start:end:step] = val
        # return a
        return b.reshape(a.shape)

    return filldiagonaloffset


@numba_funcify.register(RavelMultiIndex)
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
    mode = op.mode
    order = op.order

    if order != "C":
        raise NotImplementedError(
            "Numba does not implement `order` in `numpy.ravel_multi_index`"
        )

    if mode == "raise":

        @numba.njit
        def mode_fn(*args):
            raise ValueError("invalid entry in coordinates array")

    elif mode == "wrap":

        @numba.njit(inline="always")
        def mode_fn(new_arr, i, j, v, d):
            new_arr[i, j] = v % d

    elif mode == "clip":

        @numba.njit(inline="always")
        def mode_fn(new_arr, i, j, v, d):
            new_arr[i, j] = min(max(v, 0), d - 1)

    if node.inputs[0].ndim == 0:

        @numba.njit
        def ravelmultiindex(*inp):
            shape = inp[-1]
            arr = np.stack(inp[:-1])

            new_arr = arr.T.astype(np.float64).copy()
            for i, b in enumerate(new_arr):
                if b < 0 or b >= shape[i]:
                    mode_fn(new_arr, i, 0, b, shape[i])

            a = np.ones(len(shape), dtype=np.float64)
            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
            return np.array(a.dot(new_arr.T), dtype=np.int64)

    else:

        @numba.njit
        def ravelmultiindex(*inp):
            shape = inp[-1]
            arr = np.stack(inp[:-1])

            new_arr = arr.T.astype(np.float64).copy()
            for i, b in enumerate(new_arr):
                for j, (d, v) in enumerate(zip(shape, b)):
                    if v < 0 or v >= d:
                        mode_fn(new_arr, i, j, v, d)

            a = np.ones(len(shape), dtype=np.float64)
            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
            return a.dot(new_arr.T).astype(np.int64)

    return ravelmultiindex
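
# Worked example (illustrative only): for `shape = (3, 4, 5)` the stride
# vector built above is `a = [20, 5, 1]` (since
# `np.cumprod((5, 4))[::-1] = [20, 5]`), so a multi-index `(i, j, k)` maps
# to `20*i + 5*j + k`, matching C-ordered
# `np.ravel_multi_index((i, j, k), (3, 4, 5))`.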


@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
    axis = op.axis

    use_python = False

    if axis is not None:
        use_python = True

    if use_python:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`axis` argument to `numpy.repeat`."
            ),
            UserWarning,
        )

        ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def repeatop(x, repeats):
            with numba.objmode(ret=ret_sig):
                ret = np.repeat(x, repeats, axis)
            return ret

    else:
        repeats_ndim = node.inputs[1].ndim

        if repeats_ndim == 0:

            @numba.njit(inline="always")
            def repeatop(x, repeats):
                return np.repeat(x, repeats.item())

        else:

            @numba.njit(inline="always")
            def repeatop(x, repeats):
                return np.repeat(x, repeats)

    return repeatop
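
# Note on the object-mode fallback used above (illustrative, not part of
# this commit): `numba.objmode` temporarily re-enters the regular Python
# interpreter, so NumPy features Numba lacks (here, `np.repeat`'s `axis`
# argument) still work at the cost of the round-trip; the block's outputs
# must be type-annotated up front, which is what `ret_sig` provides.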


@numba_funcify.register(Unique)
def numba_funcify_Unique(op, node, **kwargs):
    axis = op.axis

    use_python = False

    if axis is not None:
        use_python = True

    return_index = op.return_index
    return_inverse = op.return_inverse
    return_counts = op.return_counts

    returns_multi = return_index or return_inverse or return_counts
    use_python |= returns_multi

    if not use_python:

        @numba.njit(inline="always")
        def unique(x):
            return np.unique(x)

    else:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`axis` and/or `return_*` arguments to `numpy.unique`."
            ),
            UserWarning,
        )

        if returns_multi:
            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
        else:
            ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def unique(x):
            with numba.objmode(ret=ret_sig):
                ret = np.unique(x, return_index, return_inverse, return_counts, axis)
            return ret

    return unique


@numba_funcify.register(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
    order = op.order

    if order != "C":
        raise NotImplementedError(
            "Numba does not support the `order` argument in `numpy.unravel_index`"
        )

    if len(node.outputs) == 1:

        @numba.njit(inline="always")
        def maybe_expand_dim(arr):
            return arr

    else:

        @numba.njit(inline="always")
        def maybe_expand_dim(arr):
            return np.expand_dims(arr, 1)

    @numba.njit
    def unravelindex(arr, shape):
        a = np.ones(len(shape), dtype=np.int64)
        a[1:] = shape[:0:-1]
        a = np.cumprod(a)[::-1]

        # Aesara actually returns a `tuple` of these values, instead of an
        # `ndarray`; however, this `ndarray` result should be able to be
        # unpacked into a `tuple`, so this discrepancy shouldn't really matter
        return ((maybe_expand_dim(arr) // a) % shape).T

    return unravelindex
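
# Worked example (illustrative only): for `shape = (3, 4, 5)` the code above
# builds `a = np.cumprod([1, 5, 4])[::-1] = [20, 5, 1]`, so the flat index 37
# maps to ((37 // 20) % 3, (37 // 5) % 4, (37 // 1) % 5) = (1, 3, 2), in
# agreement with `np.unravel_index(37, (3, 4, 5))`.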


@numba_funcify.register(SearchsortedOp)
def numba_funcify_Searchsorted(op, node, **kwargs):
    side = op.side

    use_python = False
    if len(node.inputs) == 3:
        use_python = True

    if use_python:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`sorter` argument to `numpy.searchsorted`."
            ),
            UserWarning,
        )

        ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def searchsorted(a, v, sorter):
            with numba.objmode(ret=ret_sig):
                ret = np.searchsorted(a, v, side, sorter)
            return ret

    else:

        @numba.njit(inline="always")
        def searchsorted(a, v):
            return np.searchsorted(a, v, side)

    return searchsorted

# aesara/link/numba/dispatch/nlinalg.py

import warnings

import numba
import numpy as np

from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
    get_numba_type,
    int_to_float_fn,
    numba_funcify,
)
from aesara.tensor.nlinalg import (
    SVD,
    Det,
    Eig,
    Eigh,
    Inv,
    MatrixInverse,
    MatrixPinv,
    QRFull,
)


@numba_funcify.register(SVD)
def numba_funcify_SVD(op, node, **kwargs):
    full_matrices = op.full_matrices
    compute_uv = op.compute_uv

    if not compute_uv:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`compute_uv` argument to `numpy.linalg.svd`."
            ),
            UserWarning,
        )

        ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def svd(x):
            with numba.objmode(ret=ret_sig):
                ret = np.linalg.svd(x, full_matrices, compute_uv)
            return ret

    else:
        out_dtype = node.outputs[0].type.numpy_dtype
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba.njit(inline="always")
        def svd(x):
            return np.linalg.svd(inputs_cast(x), full_matrices)

    return svd


@numba_funcify.register(Det)
def numba_funcify_Det(op, node, **kwargs):
    out_dtype = node.outputs[0].type.numpy_dtype
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def det(x):
        return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)

    return det
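
# Note on `int_to_float_fn` (illustrative, not part of this commit): Numba's
# `np.linalg` bindings only accept floating-point (or complex) arrays, while
# an Aesara graph may feed integer matrices to these `Op`s. The helper
# returned by `int_to_float_fn` upcasts integer inputs (e.g. `int64` to
# `float64`) and passes float inputs through, after which the result is cast
# back to the node's declared output dtype via `direct_cast`.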


@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):
    out_dtype_1 = node.outputs[0].type.numpy_dtype
    out_dtype_2 = node.outputs[1].type.numpy_dtype

    inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)

    @numba.njit
    def eig(x):
        out = np.linalg.eig(inputs_cast(x))
        return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))

    return eig


@numba_funcify.register(Eigh)
def numba_funcify_Eigh(op, node, **kwargs):
    uplo = op.UPLO

    if uplo != "L":
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`UPLO` argument to `numpy.linalg.eigh`."
            ),
            UserWarning,
        )

        out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
        ret_sig = numba.types.Tuple(
            [get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
        )

        @numba.njit
        def eigh(x):
            with numba.objmode(ret=ret_sig):
                out = np.linalg.eigh(x, UPLO=uplo)
                ret = (out[0].astype(out_dtypes[0]), out[1].astype(out_dtypes[1]))
            return ret

    else:

        @numba.njit(inline="always")
        def eigh(x):
            return np.linalg.eigh(x)

    return eigh


@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
    out_dtype = node.outputs[0].type.numpy_dtype
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def inv(x):
        return np.linalg.inv(inputs_cast(x)).astype(out_dtype)

    return inv


@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
    out_dtype = node.outputs[0].type.numpy_dtype
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def matrix_inverse(x):
        return np.linalg.inv(inputs_cast(x)).astype(out_dtype)

    return matrix_inverse


@numba_funcify.register(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
    out_dtype = node.outputs[0].type.numpy_dtype
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def matrixpinv(x):
        return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)

    return matrixpinv


@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
    mode = op.mode

    if mode != "reduced":
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`mode` argument to `numpy.linalg.qr`."
            ),
            UserWarning,
        )

        if len(node.outputs) > 1:
            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
        else:
            ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def qr_full(x):
            with numba.objmode(ret=ret_sig):
                ret = np.linalg.qr(x, mode=mode)
            return ret

    else:
        out_dtype = node.outputs[0].type.numpy_dtype
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba.njit(inline="always")
        def qr_full(x):
            return np.linalg.qr(inputs_cast(x))

    return qr_full

# aesara/link/numba/dispatch/random.py

from textwrap import dedent, indent
from typing import Any, Callable, Dict, Optional

import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib
from numpy.random import RandomState

import aesara.tensor.random.basic as aer
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify
from aesara.link.utils import (
    compile_function_src,
    get_name_for_object,
    unique_name_generator,
)
from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable


@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
    ints, index = state.get_state()[1:3]
    ptr = _helperlib.rnd_get_np_state_ptr()
    _helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
    return ints
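
# What the `numba_typify` above accomplishes (illustrative note, not part of
# this commit): Numba's nopython-mode `np.random.*` functions draw from an
# internal Mersenne Twister state exposed through `numba._helperlib`, so
# copying the `RandomState`'s `(index, keys)` pair into that state makes the
# compiled samplers continue the host `RandomState`'s stream, roughly:
#
#   rs = RandomState(2029)
#   numba_typify_RandomState(rs)  # seed Numba's internal state from `rs`
#   sample = numba.njit(lambda: np.random.uniform(0, 1))()
#   # `sample` should now match `RandomState(2029).uniform(0, 1)`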


def make_numba_random_fn(node, np_random_func):
    """Create Numba implementations for existing Numba-supported ``np.random`` functions.

    The functions generated here add parameter broadcasting and the ``size``
    argument to the Numba-supported scalar ``np.random`` functions.
    """
    tuple_size = get_vector_length(node.inputs[1])
    size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])

    # Make a broadcast-capable version of the Numba-supported scalar sampling
    # function
    bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"

    sized_fn_name = "sized_random_variable"

    unique_names = unique_name_generator(
        [
            bcast_fn_name,
            sized_fn_name,
            "np",
            "np_random_func",
            "numba_vectorize",
            "to_fixed_tuple",
            "tuple_size",
            "size_dims",
            "rng",
            "size",
            "dtype",
        ],
        suffix_sep="_",
    )

    bcast_fn_input_names = ", ".join(
        [unique_names(i, force_unique=True) for i in node.inputs[3:]]
    )
    bcast_fn_global_env = {
        "np_random_func": np_random_func,
        "numba_vectorize": numba.vectorize,
    }

    bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
    return np_random_func({bcast_fn_input_names})
    """
    bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)

    random_fn_input_names = ", ".join(
        ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
    )

    # Now, create a Numba JITable function that implements the `size` parameter
    out_dtype = node.outputs[1].type.numpy_dtype

    random_fn_global_env = {
        bcast_fn_name: bcast_fn,
        "out_dtype": out_dtype,
    }

    if tuple_size > 0:
        random_fn_body = dedent(
            f"""
        size = to_fixed_tuple(size, tuple_size)
        data = np.empty(size, dtype=out_dtype)
        for i in np.ndindex(size[:size_dims]):
            data[i] = {bcast_fn_name}({bcast_fn_input_names})
        """
        )
        random_fn_global_env.update(
            {
                "np": np,
                "to_fixed_tuple": numba_ndarray.to_fixed_tuple,
                "tuple_size": tuple_size,
                "size_dims": size_dims,
            }
        )
    else:
        random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""

    sized_fn_src = dedent(
        f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
    return (rng, data)
    """
    )
    random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
    random_fn = numba.njit(random_fn)

    return random_fn
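
# Example of the generated source (hypothetical, assuming a `normal(loc,
# scale)` node with vector parameters and a length-1 `size`): the two
# compiled pieces look roughly like
#
#   @numba_vectorize
#   def aesara_random_normal(loc, scale):
#       return np_random_func(loc, scale)
#
#   def sized_random_variable(rng, size, dtype, loc, scale):
#       size = to_fixed_tuple(size, tuple_size)
#       data = np.empty(size, dtype=out_dtype)
#       for i in np.ndindex(size[:size_dims]):
#           data[i] = aesara_random_normal(loc, scale)
#       return (rng, data)
#
# (`loc`/`scale` stand in for the names `unique_name_generator` produces.)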


@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV)  # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV)  # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
    name = op.name
    np_random_func = getattr(np.random, name)

    if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
        raise TypeError("Numba does not support NumPy `Generator`s")

    return make_numba_random_fn(node, np_random_func)


def create_numba_random_fn(
    op: Op,
    node: Apply,
    scalar_fn: Callable[[str], str],
    global_env: Optional[Dict[str, Any]] = None,
) -> Callable:
    """Create a vectorized function from a callable that generates the ``str`` function body.

    TODO: This could/should be generalized for other simple function
    construction cases that need unique-ified symbol names.
    """
    np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"

    if global_env:
        np_global_env = global_env.copy()
    else:
        np_global_env = {}

    np_global_env["np"] = np
    np_global_env["numba_vectorize"] = numba.vectorize

    unique_names = unique_name_generator(
        [
            np_random_fn_name,
        ]
        + list(np_global_env.keys())
        + [
            "rng",
            "size",
            "dtype",
        ],
        suffix_sep="_",
    )
    np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
    np_input_names = ", ".join(np_names)

    np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
    """
    np_random_fn = compile_function_src(
        np_random_fn_src, np_random_fn_name, np_global_env
    )

    return make_numba_random_fn(node, np_random_fn)


@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
    def body_fn(a, b):
        # NB: the generated body must carry its own four-space indentation,
        # since it is spliced directly under the `def` line above.
        return f"    return {a} + {b} * abs(np.random.normal(0, 1))"

    return create_numba_random_fn(op, node, body_fn)


@numba_funcify.register(aer.BernoulliRV)
def numba_funcify_BernoulliRV(op, node, **kwargs):
    out_dtype = node.outputs[1].type.numpy_dtype

    def body_fn(a):
        return f"""
    if {a} < np.random.uniform(0, 1):
        return direct_cast(0, out_dtype)
    else:
        return direct_cast(1, out_dtype)
        """

    return create_numba_random_fn(
        op,
        node,
        body_fn,
        {"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast},
    )

# aesara/link/numba/dispatch/scalar.py

from functools import reduce
from typing import List

import numba
import numpy as np
import scipy
import scipy.special

from aesara.compile.ops import ViewOp
from aesara.graph.basic import Variable
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import create_numba_signature, numba_funcify
from aesara.link.utils import (
    compile_function_src,
    get_name_for_object,
    unique_name_generator,
)
from aesara.scalar.basic import (
    Add,
    Cast,
    Clip,
    Composite,
    Identity,
    Mul,
    ScalarOp,
    Second,
    Switch,
)


@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    if scalar_func_name.startswith("scipy."):
        func_package = scipy
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    else:
        func_package = np

    if "." in scalar_func_name:
        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    scalar_op_fn_name = get_name_for_object(scalar_func)
    unique_names = unique_name_generator(
        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
    )

    input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])

    global_env = {"scalar_func": scalar_func}

    scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
    """
    scalar_op_fn = compile_function_src(scalar_op_src, scalar_op_fn_name, global_env)

    signature = create_numba_signature(node, force_scalar=True)

    return numba.njit(signature, inline="always")(scalar_op_fn)


@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
    @numba.njit(inline="always")
    def switch(condition, x, y):
        if condition:
            return x
        else:
            return y

    return switch


def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
    """Create a Numba-compatible N-ary function from a binary function."""
    unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
    input_names = [unique_names(v, force_unique=True) for v in inputs]
    input_signature = ", ".join(input_names)
    output_expr = binary_op.join(input_names)

    nary_src = f"""
def {binary_op_name}({input_signature}):
    return {output_expr}
    """
    nary_fn = compile_function_src(nary_src, binary_op_name)

    return nary_fn
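
# Example of the generated source (hypothetical names from
# `unique_name_generator`): for a three-input `Add` node this produces
#
#   def add(tensor_variable, tensor_variable_1, tensor_variable_2):
#       return tensor_variable + tensor_variable_1 + tensor_variable_2
#
# which `numba_funcify_Add` below compiles with a scalar Numba signature.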


@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
    signature = create_numba_signature(node, force_scalar=True)
    nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

    return numba.njit(signature, inline="always")(nary_add_fn)


@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
    signature = create_numba_signature(node, force_scalar=True)
    nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")

    return numba.njit(signature, inline="always")(nary_mul_fn)


@numba_funcify.register(Cast)
def numba_funcify_Cast(op, node, **kwargs):
    dtype = np.dtype(op.o_type.dtype)

    @numba.njit(inline="always")
    def cast(x):
        return numba_basic.direct_cast(x, dtype)

    return cast


@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
    @numba.njit(inline="always")
    def viewop(x):
        return x

    return viewop


@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
    @numba.njit
    def clip(_x, _min, _max):
        x = numba_basic.to_scalar(_x)
        _min_scalar = numba_basic.to_scalar(_min)
        _max_scalar = numba_basic.to_scalar(_max)

        if x < _min_scalar:
            return _min_scalar
        elif x > _max_scalar:
            return _max_scalar
        else:
            return x

    return clip


@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
    signature = create_numba_signature(node, force_scalar=True)

    composite_fn = numba.njit(signature)(
        numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
    )
    return composite_fn


@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
    @numba.njit(inline="always")
    def second(x, y):
        return y

    return second

# aesara/link/numba/dispatch/tensor_basic.py

from textwrap import indent

import numba
import numpy as np

from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from aesara.link.utils import compile_function_src, unique_name_generator
from aesara.tensor.basic import (
    Alloc,
    AllocDiag,
    AllocEmpty,
    ARange,
    ExtractDiag,
    Eye,
    Join,
    MakeVector,
    Rebroadcast,
    ScalarFromTensor,
    TensorFromScalar,
)


@numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
    global_env = {
        "np": np,
        "to_scalar": numba_basic.to_scalar,
        "dtype": np.dtype(op.dtype),
    }

    unique_names = unique_name_generator(
        ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
    )
    shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
    shape_var_item_names = [f"{name}_item" for name in shape_var_names]
    shapes_to_items_src = indent(
        "\n".join(
            [
                f"{item_name} = to_scalar({shape_name})"
                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
            ]
        ),
        " " * 4,
    )

    alloc_def_src = f"""
def allocempty({", ".join(shape_var_names)}):
{shapes_to_items_src}
    scalar_shape = {create_tuple_string(shape_var_item_names)}
    return np.empty(scalar_shape, dtype)
    """

    alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env)

    return numba.njit(alloc_fn)
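
# Example of the generated source (hypothetical names, for a rank-2 output):
#
#   def allocempty(tensor_variable, tensor_variable_1):
#       tensor_variable_item = to_scalar(tensor_variable)
#       tensor_variable_1_item = to_scalar(tensor_variable_1)
#       scalar_shape = (tensor_variable_item, tensor_variable_1_item)
#       return np.empty(scalar_shape, dtype)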


@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
    global_env = {"np": np, "to_scalar": numba_basic.to_scalar}

    unique_names = unique_name_generator(
        ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
        suffix_sep="_",
    )
    shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
    shape_var_item_names = [f"{name}_item" for name in shape_var_names]
    shapes_to_items_src = indent(
        "\n".join(
            [
                f"{item_name} = to_scalar({shape_name})"
                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
            ]
        ),
        " " * 4,
    )

    alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
    val_np = np.asarray(val)
{shapes_to_items_src}
    scalar_shape = {create_tuple_string(shape_var_item_names)}
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res
    """

    alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env)

    return numba.njit(alloc_fn)


@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
    offset = op.offset

    @numba.njit(inline="always")
    def allocdiag(v):
        return np.diag(v, k=offset)

    return allocdiag


@numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
    dtype = np.dtype(op.dtype)

    @numba.njit(inline="always")
    def arange(start, stop, step):
        return np.arange(
            numba_basic.to_scalar(start),
            numba_basic.to_scalar(stop),
            numba_basic.to_scalar(step),
            dtype=dtype,
        )

    return arange


@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
    view = op.view

    if view != -1:
        # TODO: Where (and why) is this `Join.view` even being used?  From a
        # quick search, the answer appears to be "nowhere", so we should
        # probably just remove it.
        raise NotImplementedError("The `view` parameter to `Join` is not supported")

    @numba.njit
    def join(axis, *tensors):
        return np.concatenate(tensors, numba_basic.to_scalar(axis))

    return join


@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
    offset = op.offset
    # axis1 = op.axis1
    # axis2 = op.axis2

    @numba.njit(inline="always")
    def extract_diag(x):
        return np.diag(x, k=offset)

    return extract_diag


@numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs):
    dtype = np.dtype(op.dtype)

    @numba.njit(inline="always")
    def eye(N, M, k):
        return np.eye(
            numba_basic.to_scalar(N),
            numba_basic.to_scalar(M),
            numba_basic.to_scalar(k),
            dtype=dtype,
        )

    return eye


@numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, **kwargs):
    dtype = np.dtype(op.dtype)

    @numba.njit
    def makevector(*args):
        return np.array([a.item() for a in args], dtype=dtype)

    return makevector


@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
    op_axis = tuple(op.axis.items())

    @numba.njit
    def rebroadcast(x):
        for axis, value in numba.literal_unroll(op_axis):
            if value and x.shape[axis] != 1:
                raise ValueError(
                    "Dimension in Rebroadcast's input was supposed to be 1"
                )
        return x

    return rebroadcast


@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
    @numba.njit(inline="always")
    def tensor_from_scalar(x):
        return np.array(x)

    return tensor_from_scalar


@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
    @numba.njit(inline="always")
    def scalar_from_tensor(x):
        return x.item()

    return scalar_from_tensor
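
Taken together, these registrations are what Aesara's Numba backend dispatches to when compiling a graph. A minimal end-to-end sketch (assuming the `NUMBA` compilation mode registered by the Numba linker; illustration only, not part of this commit):

import numpy as np

import aesara
import aesara.tensor as at

x = at.matrix("x")
y = at.extra_ops.cumsum(x, axis=0)  # lowered through numba_funcify(CumOp)

fn = aesara.function([x], y, mode="NUMBA")
print(fn(np.arange(6.0).reshape((2, 3))))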

aesara/link/numba/linker.py (unified diff):

-from numpy.random import RandomState
-
 from aesara.link.basic import JITLinker
@@ -18,6 +16,8 @@ class NumbaLinker(JITLinker):
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
+        from numpy.random import RandomState
+
         from aesara.link.numba.dispatch import numba_typify

         thunk_inputs = []
...

tests/link/test_numba.py (unified diff):

@@ -25,7 +25,7 @@ from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.optdb import OptimizationQuery
 from aesara.graph.type import Type
-from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
+from aesara.link.numba.dispatch import basic as numba_basic
 from aesara.link.numba.linker import NumbaLinker
 from aesara.scalar.basic import Composite
 from aesara.tensor import blas
@@ -147,20 +147,21 @@ def eval_python_only(fn_inputs, fgraph, inputs):
         else:
             return wrap

-    with mock.patch("aesara.link.numba.dispatch.numba.njit", njit_noop), mock.patch(
-        "aesara.link.numba.dispatch.numba.vectorize",
+    with mock.patch("numba.njit", njit_noop), mock.patch(
+        "numba.vectorize",
         vectorize_noop,
     ), mock.patch(
-        "aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem
+        "aesara.link.numba.dispatch.elemwise.tuple_setitem",
+        py_tuple_setitem,
     ), mock.patch(
-        "aesara.link.numba.dispatch.direct_cast", lambda x, dtype: x
+        "aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
     ), mock.patch(
-        "aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype",
+        "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
         lambda dtype: dtype,
     ), mock.patch(
-        "aesara.link.numba.dispatch.to_scalar", py_to_scalar
+        "aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar
     ), mock.patch(
-        "aesara.link.numba.dispatch.to_fixed_tuple",
+        "numba.np.unsafe.ndarray.to_fixed_tuple",
         lambda x, n: tuple(x),
     ):
         aesara_numba_fn = function(
@@ -247,7 +248,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
         else pytest.raises(NotImplementedError)
     )
     with cm:
-        res = get_numba_type(v, force_scalar=force_scalar)
+        res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
     assert res == expected
@@ -289,7 +290,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
     ],
 )
 def test_create_numba_signature(v, expected, force_scalar):
-    res = create_numba_signature(v, force_scalar=force_scalar)
+    res = numba_basic.create_numba_signature(v, force_scalar=force_scalar)
     assert res == expected
...