提交 cddf5883 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Respect predefined modes in `get_default_mode`

Also allow arbitrary capitalization of the modes. Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
上级 450efff6
......@@ -37,7 +37,6 @@ from pytensor.compile.mode import (
PrintCurrentFunctionGraph,
get_default_mode,
get_mode,
instantiated_default_mode,
local_useless,
optdb,
predefined_linkers,
......
......@@ -492,7 +492,7 @@ predefined_modes = {
"PYTORCH": PYTORCH,
}
instantiated_default_mode = None
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
def get_mode(orig_string):
......@@ -500,50 +500,46 @@ def get_mode(orig_string):
string = config.mode
else:
string = orig_string
if not isinstance(string, str):
return string # it is hopefully already a mode...
global instantiated_default_mode
# The default mode is cached. However, config.mode can change
# If instantiated_default_mode has the right class, use it.
if orig_string is None and instantiated_default_mode:
if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__
else:
default_mode_class = string
if instantiated_default_mode.__class__.__name__ == default_mode_class:
return instantiated_default_mode
if string in ("Mode", "DebugMode", "NanGuardMode"):
if string == "DebugMode":
# need to import later to break circular dependency.
from .debugmode import DebugMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == "NanGuardMode":
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode
# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
# TODO: Can't we look up the name and invoke it rather than using eval here?
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
elif string in predefined_modes:
ret = predefined_modes[string]
else:
raise Exception(f"No predefined mode exist for string: {string}")
# Keep the original string for error messages
upper_string = string.upper()
if orig_string is None:
# Build and cache the default mode
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
instantiated_default_mode = ret
if upper_string in predefined_modes:
return predefined_modes[upper_string]
global _CACHED_RUNTIME_MODES
if upper_string in _CACHED_RUNTIME_MODES:
return _CACHED_RUNTIME_MODES[upper_string]
# Need to define the mode for the first time
if upper_string == "MODE":
ret = Mode(linker=config.linker, optimizer=config.optimizer)
elif upper_string in ("DEBUGMODE", "DEBUG_MODE"):
from pytensor.compile.debugmode import DebugMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif upper_string == "NANGUARDMODE":
from pytensor.compile.nanguardmode import NanGuardMode
# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
raise ValueError(f"No predefined mode exist for string: {string}")
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret
return ret
......
......@@ -387,7 +387,8 @@ def add_compile_configvars():
config.add(
"linker",
"Default linker used if the pytensor flags mode is Mode",
EnumStr("cvm", linker_options),
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
in_c_key=False,
)
......@@ -410,6 +411,7 @@ def add_compile_configvars():
EnumStr(
"o4",
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
mutable=False, # Not mutable because the default mode is cached after the first use.
),
in_c_key=False,
)
......
......@@ -1105,14 +1105,10 @@ class TestPicklefunction:
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
)
old_default_mode = config.mode
old_default_opt = config.optimizer
old_default_link = config.linker
try:
try:
str_f = pickle.dumps(f, protocol=-1)
config.mode = "Mode"
config.linker = "py"
config.optimizer = "None"
config.mode = "NUMBA"
g = pickle.loads(str_f)
# print g.maker.mode
# print compile.mode.default_mode
......@@ -1121,8 +1117,6 @@ class TestPicklefunction:
g = "ok"
finally:
config.mode = old_default_mode
config.optimizer = old_default_opt
config.linker = old_default_link
if g == "ok":
return
......
......@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.link.basic import LocalLinker
from pytensor.link.jax import JAXLinker
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.type import matrix, vector
......@@ -142,3 +143,15 @@ def test_get_target_language():
test_mode = Mode(linker=MyLinker())
with pytest.raises(Exception):
get_target_language(test_mode)
def test_predefined_modes_respected():
default_mode = get_default_mode()
assert not isinstance(default_mode.linker, JAXLinker)
with config.change_flags(mode="JAX"):
jax_mode = get_default_mode()
assert isinstance(jax_mode.linker, JAXLinker)
default_mode_again = get_default_mode()
assert not isinstance(default_mode_again.linker, JAXLinker)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论