提交 1453ba09 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Suppress noisy numba warnings

上级 a244ab1a
......@@ -5,6 +5,7 @@ from hashlib import sha256
import numba
import numpy as np
from numba import NumbaPerformanceWarning, NumbaWarning
from numba import njit as _njit
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
......@@ -23,6 +24,35 @@ from pytensor.tensor.type import TensorType
from pytensor.tensor.utils import hash_from_ndarray
def _filter_numba_warnings():
# Suppress large global arrays cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
# TODO: We could avoid inlining large constants and pass them at runtime
warnings.filterwarnings(
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
),
category=NumbaWarning,
)
# Disable loud / incorrect warnings from Numba
# https://github.com/numba/numba/issues/10086
# TODO: Would be much better if we could disable only for our functions
warnings.filterwarnings(
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
r"np\.dot\(\) is faster on contiguous arrays"
),
category=NumbaPerformanceWarning,
)
_filter_numba_warnings()
def numba_njit(
*args, fastmath=None, final_function: bool = False, **kwargs
) -> Callable:
......
......@@ -25,6 +25,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
_filter_numba_warnings,
cache_key_for_constant,
numba_funcify_and_cache_key,
)
......@@ -455,14 +456,46 @@ def test_scalar_return_value_conversion():
assert isinstance(x_fn(1.0), np.ndarray)
@pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64")
out = pt.psi(x) * 2
fn = function([x], out, mode="NUMBA")
x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
class TestNumbaWarnings:
def setup_method(self, method):
# Pytest messes up with the package filters, reenable here for testing
_filter_numba_warnings()
@pytest.mark.filterwarnings("error")
def test_cache_pointer_func_warning_suppressed(self):
x = pt.vector("x", shape=(5,), dtype="float64")
out = pt.psi(x) * 2
fn = function([x], out, mode="NUMBA")
x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
@pytest.mark.filterwarnings("error")
def test_cache_large_global_array_warning_suppressed(self):
rng = np.random.default_rng(458)
large_constant = rng.normal(size=(100000, 5))
x = pt.vector("x", shape=(5,), dtype="float64")
out = x * large_constant
fn = function([x], out, mode="NUMBA")
x_test = rng.uniform(size=5)
np.testing.assert_allclose(fn(x_test), x_test * large_constant)
@pytest.mark.filterwarnings("error")
def test_contiguous_array_dot_warning_suppressed(self):
A = pt.matrix("A")
b = pt.vector("b")
out = pt.dot(A, b[:, None])
# Cached functions won't reemit the warning, so we have to disable it
with config.change_flags(numba__cache=False):
fn = function([A, b], out, mode="NUMBA")
A_test = np.ones((5, 5))
# Numba actually warns even on contiguous arrays: https://github.com/numba/numba/issues/10086
# But either way we don't want this warning for users as they have little control over strides
b_test = np.ones((10,))[::2]
np.testing.assert_allclose(fn(A_test, b_test), np.dot(A_test, b_test[:, None]))
@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论