Unverified 提交 3ed2c497 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Enable no-cpython-wrapper in numba where possible (#765)

* Enable no-cpython-wrapper in numba where possible * Fix test with no_cpython_wrapper * Add docstring to numba_funcify
上级 15b90be8
......@@ -59,6 +59,8 @@ def global_numba_func(func):
def numba_njit(*args, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
# Supress caching warnings
warnings.filterwarnings(
......@@ -419,7 +421,12 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node."""
"""Generate a numba function for a given op and apply node.
The resulting function will usually use the `no_cpython_wrapper`
argument in numba, so it can not be called directly from python,
but only from other jit functions.
"""
return generate_fallback_impl(op, node, storage_map, **kwargs)
......
......@@ -470,7 +470,9 @@ _jit_options = {
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
}
},
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
}
......@@ -698,7 +700,14 @@ def numba_funcify_Elemwise(op, node, **kwargs):
return tuple(outputs_summed)
return outputs_summed[0]
@overload(elemwise)
@overload(
elemwise,
jit_options={
"fastmath": flags,
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
},
)
def ov_elemwise(*inputs):
return elemwise_wrapper
......
......@@ -29,7 +29,7 @@ class NumbaLinker(JITLinker):
def jit_compile(self, fn):
from pytensor.link.numba.dispatch.basic import numba_njit
jitted_fn = numba_njit(fn)
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn
def create_thunk_inputs(self, storage_map):
......
......@@ -386,6 +386,8 @@ def test_ExtractDiag(val, offset):
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
from pytensor.link.numba.dispatch.basic import numba_njit
if reverse_axis:
axis1, axis2 = axis2, axis1
......@@ -394,7 +396,12 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = pt.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
np.testing.assert_allclose(numba_fn(x_test), np.diagonal(x_test, k, axis1, axis2))
@numba_njit(no_cpython_wrapper=False)
def wrap(x):
return numba_fn(x)
np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2))
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论