Commit 3f9cc26a authored by Ricardo Vieira, committed by ricardoV94

Test numba config flags directly

Parent 165aa19f
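The change replaces graph compilation plus `__globals__` spelunking with direct use of the `numba_njit` helper imported from `pytensor.link.numba.dispatch.basic`, which applies the `numba__fastmath` and `numba__cache` config flags when wrapping a function. As a rough sketch of the behavior the two rewritten tests pin down (`my_numba_njit` is a hypothetical stand-in, not PyTensor's actual implementation, and the exact fast-math flag set beyond what the diff shows is an assumption):

import numba
from pytensor import config

def my_numba_njit(func=None, **kwargs):
    # Hypothetical stand-in for pytensor's numba_njit; sketch only.
    if func is None:
        # Support both @my_numba_njit and @my_numba_njit(cache=True).
        return lambda f: my_numba_njit(f, **kwargs)
    if config.numba__fastmath:
        # True is expanded into a set of LLVM fast-math flags; the diff
        # shows "afn", "arcp" and "contract" among them (others elided).
        kwargs.setdefault("fastmath", {"afn", "arcp", "contract"})
    else:
        kwargs.setdefault("fastmath", False)
    # Caching is applied only when both the call site and the config flag
    # allow it (an assumption consistent with the test's assertions).
    kwargs["cache"] = kwargs.get("cache", False) and config.numba__cache
    return numba.njit(func, **kwargs)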
@@ -29,6 +29,7 @@ from pytensor.link.numba.dispatch.basic import (
     _filter_numba_warnings,
     cache_key_for_constant,
     numba_funcify_and_cache_key,
+    numba_njit,
 )
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.scalar.basic import Composite, ScalarOp, as_scalar
@@ -426,14 +427,13 @@ def test_shared_updates():
 def test_config_options_fastmath():
-    x = pt.dvector()
     with config.change_flags(numba__fastmath=True):
-        pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
-            "jitable_func"
-        ].py_func.__globals__["impl_sum"]
-        assert numba_sum_fn.targetoptions["fastmath"] == {
+
+        @numba_njit
+        def fn_fast(x):
+            return x + 1
+
+        assert fn_fast.targetoptions["fastmath"] == {
             "afn",
             "arcp",
             "contract",
@@ -442,28 +442,30 @@ def test_config_options_fastmath():
         }
     with config.change_flags(numba__fastmath=False):
-        pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
-            "jitable_func"
-        ].py_func.__globals__["impl_sum"]
-        assert numba_sum_fn.targetoptions["fastmath"] is False
+
+        @numba_njit
+        def fn_nofast(x):
+            return x + 1
+
+        assert fn_nofast.targetoptions["fastmath"] is False


 def test_config_options_cached():
-    x = pt.dvector()
     with config.change_flags(numba__cache=True):
-        pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
-            "jitable_func"
-        ].py_func.__globals__["impl_sum"]
-        assert not isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
+
+        @numba_njit(cache=True)
+        def fn_cached(x):
+            return x + 1
+
+        assert not isinstance(fn_cached._cache, numba.core.caching.NullCache)

     with config.change_flags(numba__cache=False):
-        pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
-        # Without caching we don't wrap the function in jitable_func
-        numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
-        assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)
+
+        @numba_njit
+        def fn_uncached(x):
+            return x + 1
+
+        assert isinstance(fn_uncached._cache, numba.core.caching.NullCache)


 def test_scalar_return_value_conversion():
...
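For reference, the two introspection points the rewritten tests rely on can be exercised with plain numba, outside PyTensor entirely: `targetoptions` on a dispatcher records the keyword arguments given to `njit`, and the private `_cache` attribute (which the test itself pokes at) is a `NullCache` until caching is enabled. A minimal sketch, assuming it is run as a script so on-disk caching can locate a source file; the `demo_*` names are made up:

import numba
from numba.core.caching import NullCache

@numba.njit(fastmath={"afn", "arcp", "contract"})
def demo_fast(x):
    return x + 1

# The fastmath option (a bool or a set of LLVM fast-math flags) is
# recorded verbatim in the dispatcher's targetoptions.
assert demo_fast.targetoptions["fastmath"] == {"afn", "arcp", "contract"}

@numba.njit(fastmath=False)
def demo_nofast(x):
    return x + 1

assert demo_nofast.targetoptions["fastmath"] is False

@numba.njit(cache=True)
def demo_cached(x):
    return x + 1

# cache=True swaps the default NullCache for an on-disk function cache
# at decoration time; no call or compilation is needed for this check.
assert not isinstance(demo_cached._cache, NullCache)

@numba.njit
def demo_uncached(x):
    return x + 1

assert isinstance(demo_uncached._cache, NullCache)

print("all checks passed")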