提交 afb4885e authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use some global njit functions in numba

This allows numba to reuse previous typing and compilation results if the same function is reused, which then also leads to smaller llvm modules. For the tests to continue to work we have to return those global functions through a wrapper (`basic.global_numba_func`) so that the tests are still able to disable compilation. Also remove some inline="always" arguments that don't seem to be helpful.
上级 f4de2fd2
...@@ -2,6 +2,7 @@ import operator ...@@ -2,6 +2,7 @@ import operator
import sys import sys
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy
from functools import singledispatch from functools import singledispatch
from textwrap import dedent from textwrap import dedent
from typing import Union from typing import Union
...@@ -15,7 +16,7 @@ from llvmlite import ir ...@@ -15,7 +16,7 @@ from llvmlite import ir
from numba import types from numba import types
from numba.core.errors import TypingError from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box from numba.extending import box, overload
from pytensor import config from pytensor import config
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
...@@ -47,6 +48,14 @@ from pytensor.tensor.type import TensorType ...@@ -47,6 +48,14 @@ from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
def global_numba_func(func):
"""Use to return global numba functions in numba_funcify_*.
This allows tests to remove the compilation using mock.
"""
return func
def numba_njit(*args, **kwargs): def numba_njit(*args, **kwargs):
kwargs = kwargs.copy() kwargs = kwargs.copy()
...@@ -573,29 +582,36 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): ...@@ -573,29 +582,36 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
return numba_njit(incsubtensor_fn, boundscheck=True) return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_set(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] = val
return x
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_inc(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] += val
return x
@numba_funcify.register(AdvancedIncSubtensor1) @numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc set_instead_of_inc = op.set_instead_of_inc
if set_instead_of_inc: if set_instead_of_inc:
advancedincsubtensor1_inplace = global_numba_func(
@numba_njit(boundscheck=True) advancedincsubtensor1_inplace_set
def advancedincsubtensor1_inplace(x, vals, idxs): )
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else: else:
advancedincsubtensor1_inplace = global_numba_func(
@numba_njit(boundscheck=True) advancedincsubtensor1_inplace_inc
def advancedincsubtensor1_inplace(x, vals, idxs): )
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace: if inplace:
return advancedincsubtensor1_inplace return global_numba_func(advancedincsubtensor1_inplace)
else: else:
@numba_njit @numba_njit
...@@ -606,51 +622,48 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -606,51 +622,48 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
return advancedincsubtensor1 return advancedincsubtensor1
@numba_funcify.register(DeepCopyOp) def deepcopyop(x):
def numba_funcify_DeepCopyOp(op, node, **kwargs): return copy(x)
# Scalars are apparently returned as actual Python scalar types and not
# NumPy scalars, so we need two separate Numba functions for each case.
# The type can also be RandomType with no ndims @overload(deepcopyop)
if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0: def dispatch_deepcopyop(x):
# TODO: Do we really need to compile a pass-through function like this? if isinstance(x, types.Array):
@numba_njit(inline="always") return lambda x: np.copy(x)
def deepcopyop(x):
return x
else: return lambda x: x
@numba_njit(inline="always")
def deepcopyop(x):
return x.copy()
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop return deepcopyop
@numba_njit
def makeslice(*x):
return slice(*x)
@numba_funcify.register(MakeSlice) @numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
@numba_njit return global_numba_func(makeslice)
def makeslice(*x):
return slice(*x)
return makeslice
@numba_njit
def shape(x):
return np.asarray(np.shape(x))
@numba_funcify.register(Shape) @numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs): def numba_funcify_Shape(op, **kwargs):
@numba_njit(inline="always") return global_numba_func(shape)
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i) @numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs): def numba_funcify_Shape_i(op, **kwargs):
i = op.i i = op.i
@numba_njit(inline="always") @numba_njit
def shape_i(x): def shape_i(x):
return np.shape(x)[i] return np.shape(x)[i]
...@@ -683,13 +696,13 @@ def numba_funcify_Reshape(op, **kwargs): ...@@ -683,13 +696,13 @@ def numba_funcify_Reshape(op, **kwargs):
if ndim == 0: if ndim == 0:
@numba_njit(inline="always") @numba_njit
def reshape(x, shape): def reshape(x, shape):
return x.item() return x.item()
else: else:
@numba_njit(inline="always") @numba_njit
def reshape(x, shape): def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape( return np.reshape(
...@@ -732,7 +745,7 @@ def int_to_float_fn(inputs, out_dtype): ...@@ -732,7 +745,7 @@ def int_to_float_fn(inputs, out_dtype):
args_dtype = np.dtype(f"f{out_dtype.itemsize}") args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba_njit(inline="always") @numba_njit
def inputs_cast(x): def inputs_cast(x):
return x.astype(args_dtype) return x.astype(args_dtype)
...@@ -740,7 +753,7 @@ def int_to_float_fn(inputs, out_dtype): ...@@ -740,7 +753,7 @@ def int_to_float_fn(inputs, out_dtype):
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}") args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba_njit(inline="always") @numba_njit
def inputs_cast(x): def inputs_cast(x):
return x.astype(args_dtype) return x.astype(args_dtype)
...@@ -755,7 +768,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -755,7 +768,7 @@ def numba_funcify_Dot(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit(inline="always") @numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype) return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
...@@ -770,13 +783,14 @@ def numba_funcify_Softplus(op, node, **kwargs): ...@@ -770,13 +783,14 @@ def numba_funcify_Softplus(op, node, **kwargs):
@numba_njit @numba_njit
def softplus(x): def softplus(x):
if x < -37.0: if x < -37.0:
return direct_cast(np.exp(x), x_dtype) value = np.exp(x)
elif x < 18.0: elif x < 18.0:
return direct_cast(np.log1p(np.exp(x)), x_dtype) value = np.log1p(np.exp(x))
elif x < 33.3: elif x < 33.3:
return direct_cast(x + np.exp(-x), x_dtype) value = x + np.exp(-x)
else: else:
return direct_cast(x, x_dtype) value = x
return direct_cast(value, x_dtype)
return softplus return softplus
...@@ -791,7 +805,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -791,7 +805,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit(inline="always") @numba_njit
def cholesky(a): def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
...@@ -852,7 +866,7 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -852,7 +866,7 @@ def numba_funcify_Solve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit(inline="always") @numba_njit
def solve(a, b): def solve(a, b):
return np.linalg.solve( return np.linalg.solve(
inputs_cast(a), inputs_cast(a),
......
...@@ -145,22 +145,23 @@ def {scalar_op_fn_name}({', '.join(input_names)}): ...@@ -145,22 +145,23 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
return numba_basic.numba_njit( return numba_basic.numba_njit(
signature, signature,
inline="always",
fastmath=config.numba__fastmath, fastmath=config.numba__fastmath,
# Functions that call a function pointer can't be cached
cache=False, cache=False,
)(scalar_op_fn) )(scalar_op_fn)
@numba_basic.numba_njit
def switch(condition, x, y):
if condition:
return x
else:
return y
@numba_funcify.register(Switch) @numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs): def numba_funcify_Switch(op, node, **kwargs):
@numba_basic.numba_njit(inline="always") return numba_basic.global_numba_func(switch)
def switch(condition, x, y):
if condition:
return x
else:
return y
return switch
def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str): def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
...@@ -181,26 +182,22 @@ def {binary_op_name}({input_signature}): ...@@ -181,26 +182,22 @@ def {binary_op_name}({input_signature}):
@numba_funcify.register(Add) @numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs): def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba_basic.numba_njit( return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
signature, inline="always", fastmath=config.numba__fastmath nary_add_fn
)(nary_add_fn) )
@numba_funcify.register(Mul) @numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs): def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*") return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
nary_add_fn
return numba_basic.numba_njit( )
signature, inline="always", fastmath=config.numba__fastmath
)(nary_mul_fn)
@numba_funcify.register(Cast) @numba_funcify.register(Cast)
...@@ -208,39 +205,41 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -208,39 +205,41 @@ def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype) dtype = np.dtype(op.o_type.dtype)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit
def cast(x): def cast(x):
return numba_basic.direct_cast(x, dtype) return numba_basic.direct_cast(x, dtype)
return cast return cast
@numba_basic.numba_njit
def viewop(x):
return x
@numba_funcify.register(Identity) @numba_funcify.register(Identity)
@numba_funcify.register(ViewOp) @numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs): def numba_funcify_ViewOp(op, **kwargs):
@numba_basic.numba_njit(inline="always") return numba_basic.global_numba_func(viewop)
def viewop(x):
return x
return viewop
@numba_basic.numba_njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
return x
@numba_funcify.register(Clip) @numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs): def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit return numba_basic.global_numba_func(clip)
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
return x
return clip
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
...@@ -255,69 +254,76 @@ def numba_funcify_Composite(op, node, **kwargs): ...@@ -255,69 +254,76 @@ def numba_funcify_Composite(op, node, **kwargs):
return composite_fn return composite_fn
@numba_basic.numba_njit
def second(x, y):
return y
@numba_funcify.register(Second) @numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs): def numba_funcify_Second(op, node, **kwargs):
@numba_basic.numba_njit(inline="always") return numba_basic.global_numba_func(second)
def second(x, y):
return y
return second
@numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
@numba_funcify.register(Reciprocal) @numba_funcify.register(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs): def numba_funcify_Reciprocal(op, node, **kwargs):
@numba_basic.numba_njit(inline="always") return numba_basic.global_numba_func(reciprocal)
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
return reciprocal @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@numba_funcify.register(Sigmoid) @numba_funcify.register(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs): def numba_funcify_Sigmoid(op, node, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) return numba_basic.global_numba_func(sigmoid)
def sigmoid(x):
return 1 / (1 + np.exp(-x))
return sigmoid @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def gammaln(x):
return math.lgamma(x)
@numba_funcify.register(GammaLn) @numba_funcify.register(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs): def numba_funcify_GammaLn(op, node, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) return numba_basic.global_numba_func(gammaln)
def gammaln(x):
return math.lgamma(x)
return gammaln
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
@numba_funcify.register(Log1mexp) @numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs): def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) return numba_basic.global_numba_func(logp1mexp)
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
return logp1mexp @numba_basic.numba_njit(fastmath=config.numba__fastmath)
def erf(x):
return math.erf(x)
@numba_funcify.register(Erf) @numba_funcify.register(Erf)
def numba_funcify_Erf(op, **kwargs): def numba_funcify_Erf(op, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) return numba_basic.global_numba_func(erf)
def erf(x):
return math.erf(x)
return erf
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
def erfc(x):
return math.erfc(x)
@numba_funcify.register(Erfc) @numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs): def numba_funcify_Erfc(op, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) return numba_basic.global_numba_func(erfc)
def erfc(x):
return math.erfc(x)
return erfc
...@@ -149,9 +149,18 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -149,9 +149,18 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
else: else:
return wrap return wrap
def py_global_numba_func(func):
if hasattr(func, "py_func"):
return func.py_func
return func
mocks = [ mocks = [
mock.patch("numba.njit", njit_noop), mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop), mock.patch("numba.vectorize", vectorize_noop),
mock.patch(
"pytensor.link.numba.dispatch.basic.global_numba_func",
py_global_numba_func,
),
mock.patch( mock.patch(
"pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem "pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem
), ),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论