提交 7039415c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow config.linker to change and affect both `Mode` and `FAST_RUN` abstract modes

上级 7ae5ae3c
......@@ -17,8 +17,8 @@ from pytensor.compile.function.types import (
)
from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
from pytensor.compile.mode import (
CVM,
FAST_COMPILE,
FAST_RUN,
JAX,
NUMBA,
OPT_FAST_COMPILE,
......@@ -33,6 +33,7 @@ from pytensor.compile.mode import (
PYTORCH,
AddDestroyHandler,
AddFeatureOptimizer,
C,
Mode,
PrintCurrentFunctionGraph,
get_default_mode,
......
......@@ -5,7 +5,7 @@ WRITEME
import logging
import warnings
from typing import Literal
from typing import Any, Literal
from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config
......@@ -62,23 +62,17 @@ def register_linker(name, linker):
predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
OPT_NONE = RewriteDatabaseQuery(include=[])
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM"
......@@ -316,6 +310,8 @@ class Mode:
):
if linker is None:
linker = config.linker
if isinstance(linker, str) and linker == "auto":
linker = "cvm" if config.cxx else "vm"
if isinstance(optimizer, str) and optimizer == "default":
optimizer = config.optimizer
......@@ -451,24 +447,9 @@ class Mode:
return new_mode
# If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)
C = Mode("c", "fast_run")
CVM = Mode("cvm", "fast_run")
VM = (Mode("vm", "fast_run"),)
NUMBA = Mode(
NumbaLinker(),
......@@ -489,10 +470,19 @@ MLX = Mode(
RewriteDatabaseQuery(include=["fast_run"]),
)
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
fast_run_linkers_to_mode = {
"cvm": CVM,
"vm": VM,
"numba": NUMBA,
}
predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"C": C,
"CVM": CVM,
"JAX": JAX,
......@@ -501,7 +491,7 @@ predefined_modes = {
"MLX": MLX,
}
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
def get_mode(orig_string):
......@@ -519,10 +509,20 @@ def get_mode(orig_string):
if upper_string in predefined_modes:
return predefined_modes[upper_string]
if upper_string == "FAST_RUN":
linker = config.linker
if linker == "auto":
return CVM if config.cxx else VM
return fast_run_linkers_to_mode[linker]
global _CACHED_RUNTIME_MODES
if upper_string in _CACHED_RUNTIME_MODES:
return _CACHED_RUNTIME_MODES[upper_string]
cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string
try:
return _CACHED_RUNTIME_MODES[cache_key]
except KeyError:
pass
# Need to define the mode for the first time
if upper_string == "MODE":
......@@ -548,7 +548,7 @@ def get_mode(orig_string):
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret
_CACHED_RUNTIME_MODES[cache_key] = ret
return ret
......
......@@ -371,11 +371,12 @@ def add_compile_configvars():
)
del param
default_linker = "cvm"
default_linker = "auto"
if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN
linker_options = [
"cvm",
"c|py",
"py",
"c",
......@@ -401,9 +402,8 @@ def add_compile_configvars():
config.add(
"linker",
"Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use.
EnumStr(default_linker, linker_options, mutable=False),
"Default linker used if the pytensor flags mode is Mode or FAST_RUN",
EnumStr(default_linker, linker_options, mutable=True),
in_c_key=False,
)
......
......@@ -1784,14 +1784,14 @@ class numeric_grad:
def mode_not_slow(mode):
from pytensor.compile.debugmode import DebugMode
from pytensor.compile.mode import FAST_RUN, get_mode
from pytensor.compile.mode import get_mode
if mode == "FAST_COMPILE":
return FAST_RUN
return get_mode("FAST_RUN")
mode = get_mode(mode)
if isinstance(mode, DebugMode):
opt = mode.optimizer
return FAST_RUN.clone(optimizer=opt)
return get_mode("FAST_RUN").clone(optimizer=opt)
else:
return mode
......
......@@ -3,7 +3,7 @@ import pytest
import pytensor.tensor as pt
from pytensor import config, function, grad, shared
from pytensor.compile.mode import FAST_RUN
from pytensor.compile.mode import get_mode
from pytensor.link.basic import JITLinker
from pytensor.scan.views import filter as pt_filter
from pytensor.scan.views import foldl, foldr
......@@ -65,7 +65,7 @@ def test_reduce_memory_consumption():
pt.constant(np.asarray(0.0, dtype=config.floatX)),
return_updates=False,
)
mode = FAST_RUN
mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace")
f1 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f1)
......@@ -106,7 +106,7 @@ def test_foldl_memory_consumption(return_updates):
else:
o = o_raw
mode = FAST_RUN
mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace")
f0 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f0)
......@@ -147,7 +147,7 @@ def test_foldr_memory_consumption(return_updates):
else:
o = o_raw
mode = FAST_RUN
mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace")
f1 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论