提交 9b7d707a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Suppress caching warning when compiling Numba functions

上级 475fe3a0
...@@ -14,7 +14,7 @@ import scipy ...@@ -14,7 +14,7 @@ import scipy
import scipy.special import scipy.special
from llvmlite import ir from llvmlite import ir
from numba import types from numba import types
from numba.core.errors import TypingError from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload from numba.extending import box, overload
...@@ -61,6 +61,13 @@ def global_numba_func(func): ...@@ -61,6 +61,13 @@ def global_numba_func(func):
def numba_njit(*args, **kwargs): def numba_njit(*args, **kwargs):
kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("cache", config.numba__cache)
# Supress caching warnings
warnings.filterwarnings(
"ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
category=NumbaWarning,
)
if len(args) > 0 and callable(args[0]): if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], **kwargs)(args[0]) return numba.njit(*args[1:], **kwargs)(args[0])
......
...@@ -7,6 +7,8 @@ from unittest import mock ...@@ -7,6 +7,8 @@ from unittest import mock
import numpy as np import numpy as np
import pytest import pytest
from tests.tensor.test_math_scipy import scipy
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
...@@ -1064,3 +1066,13 @@ def test_OpFromGraph(): ...@@ -1064,3 +1066,13 @@ def test_OpFromGraph():
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv]) compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64")
out = pt.psi(x) * 2
fn = function([x], out, mode="NUMBA")
x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论