提交 cddf5883 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Respect predefined modes in `get_default_mode`

Also allow arbitrary capitalization of the modes. Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
上级 450efff6
...@@ -37,7 +37,6 @@ from pytensor.compile.mode import ( ...@@ -37,7 +37,6 @@ from pytensor.compile.mode import (
PrintCurrentFunctionGraph, PrintCurrentFunctionGraph,
get_default_mode, get_default_mode,
get_mode, get_mode,
instantiated_default_mode,
local_useless, local_useless,
optdb, optdb,
predefined_linkers, predefined_linkers,
......
...@@ -492,7 +492,7 @@ predefined_modes = { ...@@ -492,7 +492,7 @@ predefined_modes = {
"PYTORCH": PYTORCH, "PYTORCH": PYTORCH,
} }
instantiated_default_mode = None _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
def get_mode(orig_string): def get_mode(orig_string):
...@@ -500,50 +500,46 @@ def get_mode(orig_string): ...@@ -500,50 +500,46 @@ def get_mode(orig_string):
string = config.mode string = config.mode
else: else:
string = orig_string string = orig_string
if not isinstance(string, str): if not isinstance(string, str):
return string # it is hopefully already a mode... return string # it is hopefully already a mode...
global instantiated_default_mode # Keep the original string for error messages
# The default mode is cached. However, config.mode can change upper_string = string.upper()
# If instantiated_default_mode has the right class, use it.
if orig_string is None and instantiated_default_mode:
if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__
else:
default_mode_class = string
if instantiated_default_mode.__class__.__name__ == default_mode_class:
return instantiated_default_mode
if string in ("Mode", "DebugMode", "NanGuardMode"):
if string == "DebugMode":
# need to import later to break circular dependency.
from .debugmode import DebugMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == "NanGuardMode":
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode
# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
# TODO: Can't we look up the name and invoke it rather than using eval here?
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
elif string in predefined_modes:
ret = predefined_modes[string]
else:
raise Exception(f"No predefined mode exist for string: {string}")
if orig_string is None: if upper_string in predefined_modes:
# Build and cache the default mode return predefined_modes[upper_string]
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":")) global _CACHED_RUNTIME_MODES
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":")) if upper_string in _CACHED_RUNTIME_MODES:
if config.optimizer_requiring: return _CACHED_RUNTIME_MODES[upper_string]
ret = ret.requiring(*config.optimizer_requiring.split(":"))
instantiated_default_mode = ret # Need to define the mode for the first time
if upper_string == "MODE":
ret = Mode(linker=config.linker, optimizer=config.optimizer)
elif upper_string in ("DEBUGMODE", "DEBUG_MODE"):
from pytensor.compile.debugmode import DebugMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif upper_string == "NANGUARDMODE":
from pytensor.compile.nanguardmode import NanGuardMode
# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
raise ValueError(f"No predefined mode exist for string: {string}")
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret
return ret return ret
......
...@@ -387,7 +387,8 @@ def add_compile_configvars(): ...@@ -387,7 +387,8 @@ def add_compile_configvars():
config.add( config.add(
"linker", "linker",
"Default linker used if the pytensor flags mode is Mode", "Default linker used if the pytensor flags mode is Mode",
EnumStr("cvm", linker_options), # Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -410,6 +411,7 @@ def add_compile_configvars(): ...@@ -410,6 +411,7 @@ def add_compile_configvars():
EnumStr( EnumStr(
"o4", "o4",
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"], ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
mutable=False, # Not mutable because the default mode is cached after the first use.
), ),
in_c_key=False, in_c_key=False,
) )
......
...@@ -1105,14 +1105,10 @@ class TestPicklefunction: ...@@ -1105,14 +1105,10 @@ class TestPicklefunction:
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s), ((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
) )
old_default_mode = config.mode old_default_mode = config.mode
old_default_opt = config.optimizer
old_default_link = config.linker
try: try:
try: try:
str_f = pickle.dumps(f, protocol=-1) str_f = pickle.dumps(f, protocol=-1)
config.mode = "Mode" config.mode = "NUMBA"
config.linker = "py"
config.optimizer = "None"
g = pickle.loads(str_f) g = pickle.loads(str_f)
# print g.maker.mode # print g.maker.mode
# print compile.mode.default_mode # print compile.mode.default_mode
...@@ -1121,8 +1117,6 @@ class TestPicklefunction: ...@@ -1121,8 +1117,6 @@ class TestPicklefunction:
g = "ok" g = "ok"
finally: finally:
config.mode = old_default_mode config.mode = old_default_mode
config.optimizer = old_default_opt
config.linker = old_default_link
if g == "ok": if g == "ok":
return return
......
...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config ...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.link.basic import LocalLinker from pytensor.link.basic import LocalLinker
from pytensor.link.jax import JAXLinker
from pytensor.tensor.math import dot, tanh from pytensor.tensor.math import dot, tanh
from pytensor.tensor.type import matrix, vector from pytensor.tensor.type import matrix, vector
...@@ -142,3 +143,15 @@ def test_get_target_language(): ...@@ -142,3 +143,15 @@ def test_get_target_language():
test_mode = Mode(linker=MyLinker()) test_mode = Mode(linker=MyLinker())
with pytest.raises(Exception): with pytest.raises(Exception):
get_target_language(test_mode) get_target_language(test_mode)
def test_predefined_modes_respected():
default_mode = get_default_mode()
assert not isinstance(default_mode.linker, JAXLinker)
with config.change_flags(mode="JAX"):
jax_mode = get_default_mode()
assert isinstance(jax_mode.linker, JAXLinker)
default_mode_again = get_default_mode()
assert not isinstance(default_mode_again.linker, JAXLinker)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论