提交 d833f039 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add missing linker and mode options in config

上级 969f76c4
...@@ -47,6 +47,8 @@ def _filter_mode(val): ...@@ -47,6 +47,8 @@ def _filter_mode(val):
"DEBUG_MODE", "DEBUG_MODE",
"JAX", "JAX",
"NUMBA", "NUMBA",
"PYTORCH",
"MLX",
] ]
if val in str_options: if val in str_options:
return val return val
...@@ -367,13 +369,25 @@ def add_compile_configvars(): ...@@ -367,13 +369,25 @@ def add_compile_configvars():
) )
del param del param
default_linker = "cvm"
if rc == 0 and config.cxx != "": if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN # Keep the default linker the same as the one for the mode FAST_RUN
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] linker_options = [
"c|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"numba",
"jax",
]
else: else:
# g++ is not present or the user disabled it, # g++ is not present or the user disabled it,
# linker should default to python only. # linker should default to python only.
linker_options = ["py", "vm_nogc"] linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
if type(config).cxx.is_default: if type(config).cxx.is_default:
# If the user provided an empty value for cxx, do not warn. # If the user provided an empty value for cxx, do not warn.
_logger.warning( _logger.warning(
...@@ -387,7 +401,7 @@ def add_compile_configvars(): ...@@ -387,7 +401,7 @@ def add_compile_configvars():
"linker", "linker",
"Default linker used if the pytensor flags mode is Mode", "Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use. # Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False), EnumStr(default_linker, linker_options, mutable=False),
in_c_key=False, in_c_key=False,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论