Commit 240827cf authored by kc611, committed by Brandon T. Willard

Added Numba cache, vectorize_target, and fastmath config options

Parent be918f5d
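For reference, a minimal sketch of how the new options could be exercised from user code. It assumes the `aesara.config` / `config.change_flags` API used in the tests added at the bottom of this diff; the `mode="NUMBA"` string and the variable names are illustrative, not part of the commit.

import aesara
import aesara.tensor as at
from aesara import config

# Defaults added by this commit: numba__cache=True, numba__fastmath=True,
# numba__vectorize_target="cpu" ("parallel" and "cuda" are the alternatives).
print(config.numba__cache, config.numba__fastmath, config.numba__vectorize_target)

x = at.dvector("x")

# Temporarily compile the Numba kernels with fastmath disabled and the
# "parallel" vectorize target; previous settings are restored on exit.
# (The same flags should also be settable through the AESARA_FLAGS
# environment variable, e.g. AESARA_FLAGS="numba__fastmath=False".)
with config.change_flags(numba__fastmath=False, numba__vectorize_target="parallel"):
    fn = aesara.function([x], 2 * x, mode="NUMBA")

print(fn([1.0, 2.0, 3.0]))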
......@@ -17,6 +17,8 @@ __pycache__
*.snm
*.toc
*.vrb
*.nbc
*.nbi
.noseids
*.DS_Store
*.bak
......
......@@ -1452,6 +1452,27 @@ def add_scan_configvars():
)
def add_numba_configvars():
config.add(
"numba__vectorize_target",
("Default target for numba.vectorize."),
EnumStr("cpu", ["parallel", "cuda"], mutable=True),
in_c_key=False,
)
config.add(
"numba__fastmath",
("If True, use Numba's fastmath mode."),
BoolParam(True),
in_c_key=False,
)
config.add(
"numba__cache",
("If True, use Numba's file based caching."),
BoolParam(True),
in_c_key=False,
)
def _get_default_gpuarray__cache_path():
return os.path.join(config.compiledir, "gpuarray_kernels")
......@@ -1683,6 +1704,7 @@ add_optimizer_configvars()
add_metaopt_configvars()
add_vm_configvars()
add_deprecated_configvars()
add_numba_configvars()
# TODO: `gcc_version_str` is used by other modules.. Should it become an immutable config var?
try:
......
......@@ -12,6 +12,7 @@ from numba import types
from numba.core.errors import TypingError
from numba.extending import box
from aesara import config
from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
......@@ -40,6 +41,21 @@ from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice
def numba_njit(*args, **kwargs):
if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.njit(*args, cache=config.numba__cache, **kwargs)
def numba_vectorize(*args, **kwargs):
if len(args) > 0 and callable(args[0]):
return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.vectorize(*args, cache=config.numba__cache, **kwargs)
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
) -> numba.types.Type:
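A side note on the wrappers introduced above: `numba_njit` and `numba_vectorize` inject the config-driven `cache` keyword while still supporting both the bare and the parameterized decorator forms. A small standalone illustration (assuming this commit is applied, so the wrapper is importable from `aesara.link.numba.dispatch.basic`):

# Both forms below end up calling numba.njit(..., cache=config.numba__cache).
from aesara.link.numba.dispatch.basic import numba_njit

@numba_njit                   # bare form: args[0] is the decorated function
def add_one(x):
    return x + 1.0

@numba_njit(inline="always")  # parameterized form: returns a decorator
def add_two(x):
    return x + 2.0

print(add_one(1.0), add_two(1.0))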
......@@ -222,19 +238,19 @@ def create_tuple_creator(f, n):
"""
assert n > 0
f = numba.njit(f)
f = numba_njit(f)
@numba.njit
@numba_njit
def creator(args):
return (f(0, *args),)
for i in range(1, n):
@numba.njit
@numba_njit
def creator(args, creator=creator, i=i):
return creator(args) + (f(i, *args),)
return numba.njit(lambda *args: creator(args))
return numba_njit(lambda *args: creator(args))
def create_tuple_string(x):
......@@ -268,7 +284,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
outputs = [[None] for i in range(n_outputs)]
......@@ -402,9 +418,11 @@ def numba_funcify_Subtensor(op, node, **kwargs):
global_env = {"np": np, "objmode": numba.objmode}
subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
)
return numba.njit(subtensor_fn)
return numba_njit(subtensor_fn)
@numba_funcify.register(IncSubtensor)
......@@ -419,10 +437,10 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
global_env = {"np": np, "objmode": numba.objmode}
incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", global_env
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
)
return numba.njit(incsubtensor_fn)
return numba_njit(incsubtensor_fn)
@numba_funcify.register(DeepCopyOp)
......@@ -434,13 +452,13 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
# The type can also be RandomType with no ndims
if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba.njit(inline="always")
@numba_njit(inline="always")
def deepcopyop(x):
return x
else:
@numba.njit(inline="always")
@numba_njit(inline="always")
def deepcopyop(x):
return x.copy()
......@@ -449,7 +467,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba.njit
@numba_njit
def makeslice(*x):
return slice(*x)
......@@ -458,7 +476,7 @@ def numba_funcify_MakeSlice(op, **kwargs):
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba.njit(inline="always")
@numba_njit(inline="always")
def shape(x):
return np.asarray(np.shape(x))
......@@ -469,7 +487,7 @@ def numba_funcify_Shape(op, **kwargs):
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba.njit(inline="always")
@numba_njit(inline="always")
def shape_i(x):
return np.shape(x)[i]
......@@ -502,13 +520,13 @@ def numba_funcify_Reshape(op, **kwargs):
if ndim == 0:
@numba.njit(inline="always")
@numba_njit(inline="always")
def reshape(x, shape):
return x.item()
else:
@numba.njit(inline="always")
@numba_njit(inline="always")
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
......@@ -521,7 +539,7 @@ def numba_funcify_Reshape(op, **kwargs):
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs):
@numba.njit
@numba_njit
def specifyshape(x, shape):
assert np.array_equal(x.shape, shape)
return x
......@@ -536,7 +554,7 @@ def int_to_float_fn(inputs, out_dtype):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba.njit(inline="always")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
......@@ -544,7 +562,7 @@ def int_to_float_fn(inputs, out_dtype):
args_dtype_sz = max([_arg.type.numpy_dtype.itemsize for _arg in inputs])
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba.njit(inline="always")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
......@@ -559,7 +577,7 @@ def numba_funcify_Dot(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_njit(inline="always")
def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
......@@ -571,7 +589,7 @@ def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)
@numba.njit
@numba_njit
def softplus(x):
if x < -37.0:
return direct_cast(np.exp(x), x_dtype)
......@@ -595,7 +613,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_njit(inline="always")
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
......@@ -612,7 +630,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
......@@ -641,7 +659,7 @@ def numba_funcify_Solve(op, node, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(
......@@ -656,7 +674,7 @@ def numba_funcify_Solve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_njit(inline="always")
def solve(a, b):
return np.linalg.solve(
inputs_cast(a),
......@@ -672,7 +690,7 @@ def numba_funcify_Solve(op, node, **kwargs):
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba.njit
@numba_njit
def batched_dot(x, y):
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
......@@ -695,7 +713,7 @@ def numba_funcify_IfElse(op, **kwargs):
if n_outs > 1:
@numba.njit
@numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
......@@ -706,7 +724,7 @@ def numba_funcify_IfElse(op, **kwargs):
else:
@numba.njit
@numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
......
......@@ -7,6 +7,7 @@ import numba
import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config
from aesara.graph.basic import Apply
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
......@@ -38,9 +39,19 @@ def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs
else:
signature = []
numba_vectorize = numba.vectorize(signature, identity=identity)
elemwise_fn = numba_vectorize(scalar_op_fn)
elemwise_fn.py_scalar_func = scalar_op_fn
target = (
getattr(node.tag, "numba__vectorize_target", None)
or config.numba__vectorize_target
)
numba_vectorized_fn = numba_basic.numba_vectorize(
signature, identity=identity, target=target, fastmath=config.numba__fastmath
)
py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
elemwise_fn = numba_vectorized_fn(scalar_op_fn)
elemwise_fn.py_scalar_func = py_scalar_func
return elemwise_fn
......@@ -85,9 +96,13 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
"""
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
inplace_elemwise_src,
inplace_elemwise_fn_name,
{**globals(), **inplace_global_env},
)
return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
inplace_elemwise_fn
)
return numba.njit(inline="always")(inplace_elemwise_fn)
return elemwise_fn
......@@ -144,13 +159,13 @@ def create_axis_reducer(
if keepdims:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return np.expand_dims(x, axis)
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return x
......@@ -160,13 +175,13 @@ def create_axis_reducer(
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
@numba.njit(boundscheck=False)
@numba_basic.numba_njit(boundscheck=False)
def careduce_axis(x):
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res = np.full(res_shape, numba_basic.to_scalar(identity), dtype=dtype)
for m in range(x.shape[axis]):
for m in numba.prange(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
return set_out_dims(res)
......@@ -175,21 +190,22 @@ def create_axis_reducer(
if keepdims:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return np.array([x], dtype)
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return numba_basic.direct_cast(x, dtype)
@numba.njit(boundscheck=False)
@numba_basic.numba_njit(boundscheck=False)
def careduce_axis(x):
res = numba_basic.to_scalar(identity)
for val in x:
res = reduce_fn(res, val)
x_ravel = x.ravel()
for i in numba.prange(x_ravel.size):
res = reduce_fn(res, x_ravel[i])
return set_out_dims(res)
return careduce_axis
......@@ -258,14 +274,16 @@ def {careduce_fn_name}({input_name}):
return {var_name}
"""
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return numba.njit(careduce_fn)
careduce_fn = compile_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
)
return numba_basic.numba_njit(fastmath=config.numba__fastmath)(careduce_fn)
def create_axis_apply_fn(fn, axis, ndim, dtype):
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
@numba.njit(boundscheck=False)
@numba_basic.numba_njit(boundscheck=False)
def axis_apply_fn(x):
x_reaxis = x.transpose(reaxis_first)
......@@ -327,7 +345,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
if len(shuffle) > 0:
@numba.njit
@numba_basic.numba_njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = tuple_setitem(new_shape, i, 1)
......@@ -341,7 +359,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)
......@@ -350,7 +368,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
lambda _: 0, ndim_new_shape
)
@numba.njit
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, transposition)
shuffle_shape = res.shape[: len(shuffle)]
......@@ -371,7 +389,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
else:
@numba.njit
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
return x.item()
......@@ -387,7 +405,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
......@@ -413,7 +431,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
reduce_max = np.max
reduce_sum = np.sum
@numba.njit
@numba_basic.numba_njit
def softmax(x):
z = reduce_max(x)
e_x = np.exp(x - z)
......@@ -439,7 +457,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
else:
reduce_sum = np.sum
@numba.njit
@numba_basic.numba_njit
def softmax_grad(dy, sm):
dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm)
......@@ -468,7 +486,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
reduce_max = np.max
reduce_sum = np.sum
@numba.njit
@numba_basic.numba_njit
def log_softmax(x):
xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
......@@ -487,7 +505,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
if x_ndim == 0:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def maxandargmax(x):
return x, 0
......@@ -511,7 +529,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
sl1 = slice(None, len(keep_axes))
sl2 = slice(len(keep_axes), None)
@numba.njit
@numba_basic.numba_njit
def maxandargmax(x):
max_res = reduce_max(x)
......
......@@ -4,6 +4,7 @@ import numba
import numpy as np
from numpy.core.multiarray import normalize_axis_index
from aesara import config
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import get_numba_type, numba_funcify
from aesara.tensor.extra_ops import (
......@@ -22,7 +23,7 @@ from aesara.tensor.extra_ops import (
@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def bartlett(x):
return np.bartlett(numba_basic.to_scalar(x))
......@@ -44,7 +45,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
np_func = np.multiply
identity = 1
@numba.njit(boundscheck=False)
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
......@@ -53,7 +54,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
x_axis_first = x.transpose(reaxis_first)
res = np.empty(x_axis_first.shape, dtype=out_dtype)
for m in range(x.shape[axis]):
for m in numba.prange(x.shape[axis]):
if m == 0:
np_func(identity, x_axis_first[m], res[m])
else:
......@@ -82,7 +83,7 @@ def numba_funcify_DiffOp(op, node, **kwargs):
op = np.not_equal if dtype == "bool" else np.subtract
@numba.njit(boundscheck=False)
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
def diffop(x):
res = x.copy()
......@@ -96,7 +97,7 @@ def numba_funcify_DiffOp(op, node, **kwargs):
@numba_funcify.register(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
@numba.njit
@numba_basic.numba_njit
def filldiagonal(a, val):
np.fill_diagonal(a, val)
return a
......@@ -106,7 +107,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
@numba_funcify.register(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba.njit
@numba_basic.numba_njit
def filldiagonaloffset(a, val, offset):
height, width = a.shape
......@@ -142,25 +143,25 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
if mode == "raise":
@numba.njit
@numba_basic.numba_njit
def mode_fn(*args):
raise ValueError("invalid entry in coordinates array")
elif mode == "wrap":
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = v % d
elif mode == "clip":
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = min(max(v, 0), d - 1)
if node.inputs[0].ndim == 0:
@numba.njit
@numba_basic.numba_njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
......@@ -176,7 +177,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
else:
@numba.njit
@numba_basic.numba_njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])
......@@ -215,7 +216,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_basic.numba_njit
def repeatop(x, repeats):
with numba.objmode(ret=ret_sig):
ret = np.repeat(x, repeats, axis)
......@@ -226,13 +227,13 @@ def numba_funcify_Repeat(op, node, **kwargs):
if repeats_ndim == 0:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats)
......@@ -257,7 +258,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if not use_python:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def unique(x):
return np.unique(x)
......@@ -276,7 +277,7 @@ def numba_funcify_Unique(op, node, **kwargs):
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_basic.numba_njit
def unique(x):
with numba.objmode(ret=ret_sig):
ret = np.unique(x, return_index, return_inverse, return_counts, axis)
......@@ -296,17 +297,17 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
if len(node.outputs) == 1:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return arr
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return np.expand_dims(arr, 1)
@numba.njit
@numba_basic.numba_njit
def unravelindex(arr, shape):
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
......@@ -339,7 +340,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_basic.numba_njit
def searchsorted(a, v, sorter):
with numba.objmode(ret=ret_sig):
ret = np.searchsorted(a, v, side, sorter)
......@@ -347,7 +348,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def searchsorted(a, v):
return np.searchsorted(a, v, side)
......
......@@ -38,7 +38,7 @@ def numba_funcify_SVD(op, node, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_basic.numba_njit
def svd(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.svd(x, full_matrices, compute_uv)
......@@ -49,7 +49,7 @@ def numba_funcify_SVD(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
......@@ -62,7 +62,7 @@ def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def det(x):
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
......@@ -77,7 +77,7 @@ def numba_funcify_Eig(op, node, **kwargs):
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
@numba.njit
@numba_basic.numba_njit
def eig(x):
out = np.linalg.eig(inputs_cast(x))
return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
......@@ -104,7 +104,7 @@ def numba_funcify_Eigh(op, node, **kwargs):
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
)
@numba.njit
@numba_basic.numba_njit
def eigh(x):
with numba.objmode(ret=ret_sig):
out = np.linalg.eigh(x, UPLO=uplo)
......@@ -113,7 +113,7 @@ def numba_funcify_Eigh(op, node, **kwargs):
else:
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def eigh(x):
return np.linalg.eigh(x)
......@@ -126,7 +126,7 @@ def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
......@@ -139,7 +139,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
......@@ -152,7 +152,7 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
......@@ -177,7 +177,7 @@ def numba_funcify_QRFull(op, node, **kwargs):
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
@numba_basic.numba_njit
def qr_full(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.qr(x, mode=mode)
......@@ -188,7 +188,7 @@ def numba_funcify_QRFull(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def qr_full(x):
return np.linalg.qr(inputs_cast(x))
......
from textwrap import dedent, indent
from typing import Any, Callable, Dict, Optional
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib, types
......@@ -129,7 +128,7 @@ def make_numba_random_fn(node, np_random_func):
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
"numba_vectorize": numba_basic.numba_vectorize,
}
bcast_fn_src = f"""
......@@ -137,7 +136,9 @@ def make_numba_random_fn(node, np_random_func):
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
bcast_fn = compile_function_src(
bcast_fn_src, bcast_fn_name, {**globals(), **bcast_fn_global_env}
)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
......@@ -179,8 +180,10 @@ def {sized_fn_name}({random_fn_input_names}):
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
random_fn = compile_function_src(
sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env}
)
random_fn = numba_basic.numba_njit(random_fn)
return random_fn
......@@ -239,7 +242,7 @@ def create_numba_random_fn(
np_global_env = {}
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba.vectorize
np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
unique_names = unique_name_generator(
[
......@@ -262,7 +265,7 @@ def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env}
)
return make_numba_random_fn(node, np_random_fn)
......
from functools import reduce
from typing import List
import numba
import numpy as np
import scipy
import scipy.special
from aesara import config
from aesara.compile.ops import ViewOp
from aesara.graph.basic import Variable
from aesara.link.numba.dispatch import basic as numba_basic
......@@ -60,16 +60,20 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
scalar_op_fn = compile_function_src(scalar_op_src, scalar_op_fn_name, global_env)
scalar_op_fn = compile_function_src(
scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
)
signature = create_numba_signature(node, force_scalar=True)
return numba.njit(signature, inline="always")(scalar_op_fn)
return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath
)(scalar_op_fn)
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def switch(condition, x, y):
if condition:
return x
......@@ -90,7 +94,7 @@ def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op:
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn = compile_function_src(nary_src, binary_op_name)
nary_fn = compile_function_src(nary_src, binary_op_name, globals())
return nary_fn
......@@ -102,7 +106,9 @@ def numba_funcify_Add(op, node, **kwargs):
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba.njit(signature, inline="always")(nary_add_fn)
return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath
)(nary_add_fn)
@numba_funcify.register(Mul)
......@@ -112,7 +118,9 @@ def numba_funcify_Mul(op, node, **kwargs):
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba.njit(signature, inline="always")(nary_mul_fn)
return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath
)(nary_mul_fn)
@numba_funcify.register(Cast)
......@@ -120,7 +128,7 @@ def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def cast(x):
return numba_basic.direct_cast(x, dtype)
......@@ -130,7 +138,7 @@ def numba_funcify_Cast(op, node, **kwargs):
@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def viewop(x):
return x
......@@ -139,7 +147,7 @@ def numba_funcify_ViewOp(op, **kwargs):
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba.njit
@numba_basic.numba_njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
......@@ -158,7 +166,7 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
composite_fn = numba.njit(signature)(
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
......@@ -166,7 +174,7 @@ def numba_funcify_Composite(op, node, **kwargs):
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def second(x, y):
return y
......@@ -175,7 +183,7 @@ def numba_funcify_Second(op, node, **kwargs):
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def inv(x):
return 1 / x
......
import numba
import numpy as np
from numba import types
from numba.extending import overload
from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
......@@ -35,7 +35,7 @@ def array0d_range(x):
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
numba_at_inner_func = numba.njit(numba_funcify(inner_fg, **kwargs))
numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))
n_seqs = op.info.n_seqs
n_mit_mot = op.info.n_mit_mot
......@@ -150,6 +150,8 @@ def scan(n_steps, {", ".join(input_names)}):
outer_in_nit_sot_names
)}
"""
scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env)
scalar_op_fn = compile_function_src(
scan_op_src, "scan", {**globals(), **global_env}
)
return numba.njit(scalar_op_fn)
return numba_basic.numba_njit(scalar_op_fn)
......@@ -52,9 +52,11 @@ def allocempty({", ".join(shape_var_names)}):
return np.empty(scalar_shape, dtype)
"""
alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env)
alloc_fn = compile_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env}
)
return numba.njit(alloc_fn)
return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(Alloc)
......@@ -88,16 +90,16 @@ def alloc(val, {", ".join(shape_var_names)}):
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env)
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
return numba.njit(alloc_fn)
return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
offset = op.offset
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def allocdiag(v):
return np.diag(v, k=offset)
......@@ -108,7 +110,7 @@ def numba_funcify_AllocDiag(op, **kwargs):
def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def arange(start, stop, step):
return np.arange(
numba_basic.to_scalar(start),
......@@ -130,7 +132,7 @@ def numba_funcify_Join(op, **kwargs):
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba.njit
@numba_basic.numba_njit
def join(axis, *tensors):
return np.concatenate(tensors, numba_basic.to_scalar(axis))
......@@ -143,7 +145,7 @@ def numba_funcify_ExtractDiag(op, **kwargs):
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
......@@ -154,7 +156,7 @@ def numba_funcify_ExtractDiag(op, **kwargs):
def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def eye(N, M, k):
return np.eye(
numba_basic.to_scalar(N),
......@@ -187,16 +189,18 @@ def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=np.{dtype})
"""
makevector_fn = compile_function_src(makevector_def_src, "makevector", global_env)
makevector_fn = compile_function_src(
makevector_def_src, "makevector", {**globals(), **global_env}
)
return numba.njit(makevector_fn)
return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items())
@numba.njit
@numba_basic.numba_njit
def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis):
if value and x.shape[axis] != 1:
......@@ -210,7 +214,7 @@ def numba_funcify_Rebroadcast(op, **kwargs):
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def tensor_from_scalar(x):
return np.array(x)
......@@ -219,7 +223,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit(inline="always")
@numba_basic.numba_njit(inline="always")
def scalar_from_tensor(x):
return x.item()
......
......@@ -14,8 +14,7 @@ from typing import Any, Callable, Dict, Iterable, List, NoReturn, Optional, Tupl
import numpy as np
from aesara import utils
from aesara.configdefaults import config
from aesara import config, utils
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
......@@ -768,7 +767,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
local_env = locals()
fgraph_def = compile_function_src(
fgraph_def_src, fgraph_name, global_env, local_env
fgraph_def_src, fgraph_name, {**globals(), **global_env}, local_env
)
return fgraph_def
......@@ -151,23 +151,27 @@ def eval_python_only(fn_inputs, fgraph, inputs):
else:
return wrap
with mock.patch("numba.njit", njit_noop), mock.patch(
"numba.vectorize",
vectorize_noop,
), mock.patch(
"aesara.link.numba.dispatch.elemwise.tuple_setitem",
py_tuple_setitem,
), mock.patch(
"aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
), mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
), mock.patch(
"aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar
), mock.patch(
"numba.np.unsafe.ndarray.to_fixed_tuple",
lambda x, n: tuple(x),
):
mocks = [
mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch(
"aesara.link.numba.dispatch.elemwise.tuple_setitem", py_tuple_setitem
),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
]
with contextlib.ExitStack() as stack:
for ctx in mocks:
stack.enter_context(ctx)
aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
......@@ -330,7 +334,6 @@ def test_numba_box_unbox(input, wrapper_fn, check_fn):
None,
),
(
# This also tests the use of repeated arguments
[at.matrix(), at.scalar()],
[rng.normal(size=(2, 2)).astype(config.floatX), 0.0],
lambda a, b: at.switch(a, b, a),
......@@ -3272,3 +3275,38 @@ def test_numba_ifelse(inputs, cond_fn, true_vals, false_vals):
out_fg = FunctionGraph(inputs, out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409")
def test_config_options_parallel():
x = at.dvector()
with config.change_flags(numba__vectorize_target="parallel"):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["parallel"] is True
def test_config_options_fastmath():
x = at.dvector()
with config.change_flags(numba__fastmath=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["fastmath"] is True
def test_config_options_cached():
x = at.dvector()
with config.change_flags(numba__cache=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
assert not isinstance(
numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
)
with config.change_flags(numba__cache=False):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)