Commit 1fc678c5 authored by Ricardo Vieira, committed by Ricardo Vieira

Use more specific Numba fastmath flags everywhere

Parent: ab3704b3
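Previously every jitted helper passed `fastmath=config.numba__fastmath` straight to Numba, which for `numba__fastmath=True` enables all LLVM fast-math flags, including `nnan` and `ninf` (the compiler may then assume NaN and inf never occur). This commit centralizes the choice inside `numba_njit` and narrows it to a safer subset. A minimal standalone sketch of the difference, assuming only a recent Numba release (the function names are illustrative):

import numba
import numpy as np

# fastmath=True turns on every LLVM fast-math flag, including "nnan"/"ninf".
@numba.njit(fastmath=True)
def sum_all_flags(x):
    return x.sum()

# Numba also accepts a set of individual flag names, so NaN/inf semantics
# survive while reassociation, contraction, etc. remain enabled.
@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
def sum_flag_subset(x):
    return x.sum()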
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
     if mode == "add":
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumsum(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
     else:
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumprod(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
......
@@ -49,10 +49,23 @@ def global_numba_func(func):
     return func
 
 
-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
     kwargs.setdefault("cache", config.numba__cache)
     kwargs.setdefault("no_cpython_wrapper", True)
     kwargs.setdefault("no_cfunc_wrapper", True)
+    if fastmath is None:
+        if config.numba__fastmath:
+            # Opinionated default on fastmath flags
+            # https://llvm.org/docs/LangRef.html#fast-math-flags
+            fastmath = {
+                "arcp",  # Allow Reciprocal
+                "contract",  # Allow floating-point contraction
+                "afn",  # Approximate functions
+                "reassoc",
+                "nsz",  # no-signed zeros
+            }
+        else:
+            fastmath = False
 
     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs):
     )
 
     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], **kwargs)(args[0])
+        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
 
-    return numba.njit(*args, **kwargs)
+    return numba.njit(*args, fastmath=fastmath, **kwargs)
 
 
 def numba_vectorize(*args, **kwargs):
......
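The `callable(args[0])` branch above keeps `numba_njit` usable both as a bare decorator (`@numba_njit`) and as a decorator factory (`@numba_njit(boundscheck=False)`), which the call sites changed below rely on. A self-contained sketch of that dispatch pattern, with the PyTensor config lookup replaced by a hard-coded default for illustration:

import numba

def my_njit(*args, fastmath=None, **kwargs):
    # Resolve the default flag set only when the caller did not pick one.
    if fastmath is None:
        fastmath = {"arcp", "contract", "afn", "reassoc", "nsz"}
    # Bare-decorator use: my_njit(func) receives the function directly.
    if len(args) > 0 and callable(args[0]):
        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
    # Factory use: my_njit(boundscheck=False) returns a decorator.
    return numba.njit(*args, fastmath=fastmath, **kwargs)

@my_njit
def plus_one(x):
    return x + 1.0

@my_njit(boundscheck=False)
def first(x):
    return x[0]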
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         core_op,
         node=core_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
......
@@ -6,7 +6,6 @@ import numpy as np
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
-from pytensor import config
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -281,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)
 
@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
 
@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):
 
     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
 
     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)
@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum
@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
......
@@ -4,7 +4,6 @@ from typing import cast
 import numba
 import numpy as np
 
-from pytensor import config
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if mode == "add":
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumsum(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     else:
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumprod(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
......
@@ -2,7 +2,6 @@ import math
 
 import numpy as np
 
-from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
     return numba_basic.numba_njit(
         signature,
-        fastmath=config.numba__fastmath,
         # Functions that call a function pointer can't be cached
         cache=False,
     )(scalar_op_fn)
 
@@ -177,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
 
-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)
 
 
 @numba_funcify.register(Mul)
@@ -187,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
 
-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)
 
 
 @numba_funcify.register(Cast)
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
     _ = kwargs.pop("storage_map", None)
 
-    composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+    composite_fn = numba_basic.numba_njit(signature)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
     return numba_basic.global_numba_func(reciprocal)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
     return numba_basic.global_numba_func(sigmoid)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def gammaln(x):
     return math.lgamma(x)
 
@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
     return numba_basic.global_numba_func(gammaln)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def logp1mexp(x):
     if x < np.log(0.5):
         return np.log1p(-np.exp(x))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
     return numba_basic.global_numba_func(logp1mexp)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erf(x):
     return math.erf(x)
 
@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
     return numba_basic.global_numba_func(erf)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erfc(x):
     return math.erfc(x)
 
......
@@ -838,7 +838,13 @@ def test_config_options_fastmath():
     pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
     print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
     numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
-    assert numba_mul_fn.targetoptions["fastmath"] is True
+    assert numba_mul_fn.targetoptions["fastmath"] == {
+        "afn",
+        "arcp",
+        "contract",
+        "nsz",
+        "reassoc",
+    }
 
 
 def test_config_options_cached():
......
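The updated assertion works because Numba records the resolved fastmath setting on each dispatcher's `targetoptions` dict, so a flag set passed at compile time can be inspected afterwards. A small sketch of that inspection, with a throwaway function:

import numba

@numba.njit(fastmath={"nsz", "reassoc"})
def f(x):
    return x + 0.0

# The dispatcher keeps the compile options, including the fastmath flag set.
print(f.targetoptions["fastmath"])  # {'nsz', 'reassoc'}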
@@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
 
@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
             if not isinstance(i, SharedVariable | Constant)
         ],
     )
+
+
+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+    # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
+    compare_numba_and_py(
+        ([x], [out]),
+        [np.array([1, 0], dtype="float64")],
+    )
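This new test pins down the behavioral reason for dropping full fastmath: LLVM's `nnan` flag (implied by `fastmath=True`) lets the compiler fold `isnan` checks to a constant `False`, while the restricted flag set leaves them intact. A hedged standalone illustration; whether the check actually gets folded depends on the Numba/LLVM version:

import numba
import numpy as np

@numba.njit(fastmath=True)  # includes "nnan": LLVM may assume NaNs never occur
def any_nan_full(x):
    return np.isnan(x).any()

@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
def any_nan_subset(x):
    return np.isnan(x).any()

x = np.array([1.0, np.nan])
# any_nan_full(x) may wrongly return False because the NaN check can be
# optimized away; any_nan_subset(x) returns True, matching NumPy.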