提交 9b7d707a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Suppress caching warning when compiling Numba functions

上级 475fe3a0
......@@ -14,7 +14,7 @@ import scipy
import scipy.special
from llvmlite import ir
from numba import types
from numba.core.errors import TypingError
from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload
......@@ -61,6 +61,13 @@ def global_numba_func(func):
def numba_njit(*args, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
# Supress caching warnings
warnings.filterwarnings(
"ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
category=NumbaWarning,
)
if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], **kwargs)(args[0])
......
......@@ -7,6 +7,8 @@ from unittest import mock
import numpy as np
import pytest
from tests.tensor.test_math_scipy import scipy
numba = pytest.importorskip("numba")
......@@ -1064,3 +1066,13 @@ def test_OpFromGraph():
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64")
out = pt.psi(x) * 2
fn = function([x], out, mode="NUMBA")
x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论