提交 1d9fa843 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): Add warnings for objectmode

上级 426e0353
...@@ -560,11 +560,18 @@ def {fn_name}({", ".join(input_names)}): ...@@ -560,11 +560,18 @@ def {fn_name}({", ".join(input_names)}):
@numba_funcify.register(Subtensor) @numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor1) @numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs): def numba_funcify_Subtensor(op, node, **kwargs):
subtensor_def_src = create_index_func( objmode = isinstance(op, AdvancedSubtensor)
node, objmode=isinstance(op, AdvancedSubtensor) if objmode:
) warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)
subtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np, "objmode": numba.objmode} global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
subtensor_fn = compile_function_src( subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env} subtensor_def_src, "subtensor", {**globals(), **global_env}
...@@ -575,11 +582,18 @@ def numba_funcify_Subtensor(op, node, **kwargs): ...@@ -575,11 +582,18 @@ def numba_funcify_Subtensor(op, node, **kwargs):
@numba_funcify.register(IncSubtensor) @numba_funcify.register(IncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs): def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src = create_index_func( objmode = isinstance(op, AdvancedIncSubtensor)
node, objmode=isinstance(op, AdvancedIncSubtensor) if objmode:
) warnings.warn(
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
UserWarning,
)
incsubtensor_def_src = create_index_func(node, objmode=objmode)
global_env = {"np": np, "objmode": numba.objmode} global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode
incsubtensor_fn = compile_function_src( incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论