提交 1fc678c5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use more specific Numba fastmath flags everywhere

上级 ab3704b3
...@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`: ...@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
if mode == "add": if mode == "add":
if axis is None or ndim == 1: if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit()
def cumop(x): def cumop(x):
return np.cumsum(x) return np.cumsum(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
if x.shape[axis] < 2: if x.shape[axis] < 2:
...@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`: ...@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
else: else:
if axis is None or ndim == 1: if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit()
def cumop(x): def cumop(x):
return np.cumprod(x) return np.cumprod(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
if x.shape[axis] < 2: if x.shape[axis] < 2:
......
...@@ -49,10 +49,23 @@ def global_numba_func(func): ...@@ -49,10 +49,23 @@ def global_numba_func(func):
return func return func
def numba_njit(*args, **kwargs): def numba_njit(*args, fastmath=None, **kwargs):
kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True)
if fastmath is None:
if config.numba__fastmath:
# Opinionated default on fastmath flags
# https://llvm.org/docs/LangRef.html#fast-math-flags
fastmath = {
"arcp", # Allow Reciprocal
"contract", # Allow floating-point contraction
"afn", # Approximate functions
"reassoc",
"nsz", # no-signed zeros
}
else:
fastmath = False
# Suppress cache warning for internal functions # Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba # We have to add an ansi escape code for optional bold text by numba
...@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs): ...@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs):
) )
if len(args) > 0 and callable(args[0]): if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], **kwargs)(args[0]) return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
return numba.njit(*args, **kwargs) return numba.njit(*args, fastmath=fastmath, **kwargs)
def numba_vectorize(*args, **kwargs): def numba_vectorize(*args, **kwargs):
......
...@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): ...@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_op, core_op,
node=core_node, node=core_node,
parent_node=node, parent_node=node,
fastmath=_jit_options["fastmath"],
**kwargs, **kwargs,
) )
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
......
...@@ -6,7 +6,6 @@ import numpy as np ...@@ -6,7 +6,6 @@ import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
...@@ -281,7 +280,6 @@ def jit_compile_reducer( ...@@ -281,7 +280,6 @@ def jit_compile_reducer(
res = numba_basic.numba_njit( res = numba_basic.numba_njit(
*args, *args,
boundscheck=False, boundscheck=False,
fastmath=config.numba__fastmath,
**kwds, **kwds,
)(fn) )(fn)
...@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
op.scalar_op, op.scalar_op,
node=scalar_node, node=scalar_node,
parent_node=node, parent_node=node,
fastmath=_jit_options["fastmath"],
**kwargs, **kwargs,
) )
...@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs): ...@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):
if ndim_input == len(axes): if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case # Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit(fastmath=config.numba__fastmath) @numba_njit
def impl_sum(array): def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0: elif len(axes) == 0:
# These cases should be removed by rewrites! # These cases should be removed by rewrites!
@numba_njit(fastmath=config.numba__fastmath) @numba_njit
def impl_sum(array): def impl_sum(array):
return np.asarray(array, dtype=out_dtype) return np.asarray(array, dtype=out_dtype)
...@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(boundscheck=False)
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py) reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py) reduce_sum = jit_fn(reduce_sum_py)
else: else:
...@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(boundscheck=False)
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_sum = jit_fn(reduce_sum_py) reduce_sum = jit_fn(reduce_sum_py)
else: else:
reduce_sum = np.sum reduce_sum = np.sum
...@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(boundscheck=False)
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py) reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py) reduce_sum = jit_fn(reduce_sum_py)
else: else:
......
...@@ -4,7 +4,6 @@ from typing import cast ...@@ -4,7 +4,6 @@ from typing import cast
import numba import numba
import numpy as np import numpy as np
from pytensor import config
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
...@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
if mode == "add": if mode == "add":
if axis is None or ndim == 1: if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumsum(x) return np.cumsum(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
if x.shape[axis] < 2: if x.shape[axis] < 2:
...@@ -74,13 +73,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -74,13 +73,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
else: else:
if axis is None or ndim == 1: if axis is None or ndim == 1:
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumprod(x) return np.cumprod(x)
else: else:
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) @numba_basic.numba_njit(boundscheck=False)
def cumop(x): def cumop(x):
out_dtype = x.dtype out_dtype = x.dtype
if x.shape[axis] < 2: if x.shape[axis] < 2:
......
...@@ -2,7 +2,6 @@ import math ...@@ -2,7 +2,6 @@ import math
import numpy as np import numpy as np
from pytensor import config
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
...@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}): ...@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
return numba_basic.numba_njit( return numba_basic.numba_njit(
signature, signature,
fastmath=config.numba__fastmath,
# Functions that call a function pointer can't be cached # Functions that call a function pointer can't be cached
cache=False, cache=False,
)(scalar_op_fn) )(scalar_op_fn)
...@@ -177,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs): ...@@ -177,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( return numba_basic.numba_njit(signature)(nary_add_fn)
nary_add_fn
)
@numba_funcify.register(Mul) @numba_funcify.register(Mul)
...@@ -187,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs): ...@@ -187,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( return numba_basic.numba_njit(signature)(nary_add_fn)
nary_add_fn
)
@numba_funcify.register(Cast) @numba_funcify.register(Cast)
...@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs): ...@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
_ = kwargs.pop("storage_map", None) _ = kwargs.pop("storage_map", None)
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( composite_fn = numba_basic.numba_njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs) numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
) )
return composite_fn return composite_fn
...@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs): ...@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal) return numba_basic.global_numba_func(reciprocal)
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def sigmoid(x): def sigmoid(x):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
...@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs): ...@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid) return numba_basic.global_numba_func(sigmoid)
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def gammaln(x): def gammaln(x):
return math.lgamma(x) return math.lgamma(x)
...@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs): ...@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln) return numba_basic.global_numba_func(gammaln)
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def logp1mexp(x): def logp1mexp(x):
if x < np.log(0.5): if x < np.log(0.5):
return np.log1p(-np.exp(x)) return np.log1p(-np.exp(x))
...@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs): ...@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp) return numba_basic.global_numba_func(logp1mexp)
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def erf(x): def erf(x):
return math.erf(x) return math.erf(x)
...@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs): ...@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf) return numba_basic.global_numba_func(erf)
@numba_basic.numba_njit(fastmath=config.numba__fastmath) @numba_basic.numba_njit
def erfc(x): def erfc(x):
return math.erfc(x) return math.erfc(x)
......
...@@ -838,7 +838,13 @@ def test_config_options_fastmath(): ...@@ -838,7 +838,13 @@ def test_config_options_fastmath():
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__)) print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] is True assert numba_mul_fn.targetoptions["fastmath"] == {
"afn",
"arcp",
"contract",
"nsz",
"reassoc",
}
def test_config_options_cached(): def test_config_options_cached():
......
...@@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import SharedVariable ...@@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
...@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype): ...@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
if not isinstance(i, SharedVariable | Constant) if not isinstance(i, SharedVariable | Constant)
], ],
) )
@pytest.mark.parametrize("composite", (False, True))
def test_isnan(composite):
# Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
x = tensor(shape=(2,), dtype="float64")
if composite:
x_scalar = psb.float64()
scalar_out = ~psb.isnan(x_scalar)
out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
else:
out = pt.isnan(x)
compare_numba_and_py(
([x], [out]),
[np.array([1, 0], dtype="float64")],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论