提交 3198d946 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Disable numba cache for cython functions

上级 3ed908b2
......@@ -48,10 +48,14 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs):
kwargs = kwargs.copy()
if "cache" not in kwargs:
kwargs["cache"] = config.numba__cache
if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.njit(*args[1:], **kwargs)(args[0])
return numba.njit(*args, cache=config.numba__cache, **kwargs)
return numba.njit(*args, **kwargs)
def numba_vectorize(*args, **kwargs):
......
......@@ -144,7 +144,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
signature = create_numba_signature(node, force_scalar=True)
return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath
signature, inline="always", fastmath=config.numba__fastmath, cache=False,
)(scalar_op_fn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论