提交 7039415c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow config.linker to change and affect both `Mode` and `FAST_RUN` abstract modes

上级 7ae5ae3c
...@@ -17,8 +17,8 @@ from pytensor.compile.function.types import ( ...@@ -17,8 +17,8 @@ from pytensor.compile.function.types import (
) )
from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
from pytensor.compile.mode import ( from pytensor.compile.mode import (
CVM,
FAST_COMPILE, FAST_COMPILE,
FAST_RUN,
JAX, JAX,
NUMBA, NUMBA,
OPT_FAST_COMPILE, OPT_FAST_COMPILE,
...@@ -33,6 +33,7 @@ from pytensor.compile.mode import ( ...@@ -33,6 +33,7 @@ from pytensor.compile.mode import (
PYTORCH, PYTORCH,
AddDestroyHandler, AddDestroyHandler,
AddFeatureOptimizer, AddFeatureOptimizer,
C,
Mode, Mode,
PrintCurrentFunctionGraph, PrintCurrentFunctionGraph,
get_default_mode, get_default_mode,
......
...@@ -5,7 +5,7 @@ WRITEME ...@@ -5,7 +5,7 @@ WRITEME
import logging import logging
import warnings import warnings
from typing import Literal from typing import Any, Literal
from pytensor.compile.function.types import Supervisor from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -62,23 +62,17 @@ def register_linker(name, linker): ...@@ -62,23 +62,17 @@ def register_linker(name, linker):
predefined_linkers[name] = linker predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor OPT_NONE = RewriteDatabaseQuery(include=[])
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations # Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude) OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
# Even if multiple merge optimizer call will be there, this shouldn't # Even if multiple merge optimizer call will be there, this shouldn't
# impact performance. # impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude) OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude) OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable") OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude) OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude) OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE" OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM" OPT_MINIMUM.name = "OPT_MINIMUM"
...@@ -316,6 +310,8 @@ class Mode: ...@@ -316,6 +310,8 @@ class Mode:
): ):
if linker is None: if linker is None:
linker = config.linker linker = config.linker
if isinstance(linker, str) and linker == "auto":
linker = "cvm" if config.cxx else "vm"
if isinstance(optimizer, str) and optimizer == "default": if isinstance(optimizer, str) and optimizer == "default":
optimizer = config.optimizer optimizer = config.optimizer
...@@ -451,24 +447,9 @@ class Mode: ...@@ -451,24 +447,9 @@ class Mode:
return new_mode return new_mode
# If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)
C = Mode("c", "fast_run") C = Mode("c", "fast_run")
CVM = Mode("cvm", "fast_run") CVM = Mode("cvm", "fast_run")
VM = (Mode("vm", "fast_run"),)
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
...@@ -489,10 +470,19 @@ MLX = Mode( ...@@ -489,10 +470,19 @@ MLX = Mode(
RewriteDatabaseQuery(include=["fast_run"]), RewriteDatabaseQuery(include=["fast_run"]),
) )
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
fast_run_linkers_to_mode = {
"cvm": CVM,
"vm": VM,
"numba": NUMBA,
}
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"C": C, "C": C,
"CVM": CVM, "CVM": CVM,
"JAX": JAX, "JAX": JAX,
...@@ -501,7 +491,7 @@ predefined_modes = { ...@@ -501,7 +491,7 @@ predefined_modes = {
"MLX": MLX, "MLX": MLX,
} }
_CACHED_RUNTIME_MODES: dict[str, Mode] = {} _CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
def get_mode(orig_string): def get_mode(orig_string):
...@@ -519,10 +509,20 @@ def get_mode(orig_string): ...@@ -519,10 +509,20 @@ def get_mode(orig_string):
if upper_string in predefined_modes: if upper_string in predefined_modes:
return predefined_modes[upper_string] return predefined_modes[upper_string]
if upper_string == "FAST_RUN":
linker = config.linker
if linker == "auto":
return CVM if config.cxx else VM
return fast_run_linkers_to_mode[linker]
global _CACHED_RUNTIME_MODES global _CACHED_RUNTIME_MODES
if upper_string in _CACHED_RUNTIME_MODES: cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string
return _CACHED_RUNTIME_MODES[upper_string]
try:
return _CACHED_RUNTIME_MODES[cache_key]
except KeyError:
pass
# Need to define the mode for the first time # Need to define the mode for the first time
if upper_string == "MODE": if upper_string == "MODE":
...@@ -548,7 +548,7 @@ def get_mode(orig_string): ...@@ -548,7 +548,7 @@ def get_mode(orig_string):
if config.optimizer_requiring: if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":")) ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Cache the mode for next time # Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret _CACHED_RUNTIME_MODES[cache_key] = ret
return ret return ret
......
...@@ -371,11 +371,12 @@ def add_compile_configvars(): ...@@ -371,11 +371,12 @@ def add_compile_configvars():
) )
del param del param
default_linker = "cvm" default_linker = "auto"
if rc == 0 and config.cxx != "": if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN # Keep the default linker the same as the one for the mode FAST_RUN
linker_options = [ linker_options = [
"cvm",
"c|py", "c|py",
"py", "py",
"c", "c",
...@@ -401,9 +402,8 @@ def add_compile_configvars(): ...@@ -401,9 +402,8 @@ def add_compile_configvars():
config.add( config.add(
"linker", "linker",
"Default linker used if the pytensor flags mode is Mode", "Default linker used if the pytensor flags mode is Mode or FAST_RUN",
# Not mutable because the default mode is cached after the first use. EnumStr(default_linker, linker_options, mutable=True),
EnumStr(default_linker, linker_options, mutable=False),
in_c_key=False, in_c_key=False,
) )
......
...@@ -1784,14 +1784,14 @@ class numeric_grad: ...@@ -1784,14 +1784,14 @@ class numeric_grad:
def mode_not_slow(mode): def mode_not_slow(mode):
from pytensor.compile.debugmode import DebugMode from pytensor.compile.debugmode import DebugMode
from pytensor.compile.mode import FAST_RUN, get_mode from pytensor.compile.mode import get_mode
if mode == "FAST_COMPILE": if mode == "FAST_COMPILE":
return FAST_RUN return get_mode("FAST_RUN")
mode = get_mode(mode) mode = get_mode(mode)
if isinstance(mode, DebugMode): if isinstance(mode, DebugMode):
opt = mode.optimizer opt = mode.optimizer
return FAST_RUN.clone(optimizer=opt) return get_mode("FAST_RUN").clone(optimizer=opt)
else: else:
return mode return mode
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, grad, shared from pytensor import config, function, grad, shared
from pytensor.compile.mode import FAST_RUN from pytensor.compile.mode import get_mode
from pytensor.link.basic import JITLinker from pytensor.link.basic import JITLinker
from pytensor.scan.views import filter as pt_filter from pytensor.scan.views import filter as pt_filter
from pytensor.scan.views import foldl, foldr from pytensor.scan.views import foldl, foldr
...@@ -65,7 +65,7 @@ def test_reduce_memory_consumption(): ...@@ -65,7 +65,7 @@ def test_reduce_memory_consumption():
pt.constant(np.asarray(0.0, dtype=config.floatX)), pt.constant(np.asarray(0.0, dtype=config.floatX)),
return_updates=False, return_updates=False,
) )
mode = FAST_RUN mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace") mode = mode.excluding("inplace")
f1 = function([], o, mode=mode) f1 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f1) inputs, outputs = clone_optimized_graph(f1)
...@@ -106,7 +106,7 @@ def test_foldl_memory_consumption(return_updates): ...@@ -106,7 +106,7 @@ def test_foldl_memory_consumption(return_updates):
else: else:
o = o_raw o = o_raw
mode = FAST_RUN mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace") mode = mode.excluding("inplace")
f0 = function([], o, mode=mode) f0 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f0) inputs, outputs = clone_optimized_graph(f0)
...@@ -147,7 +147,7 @@ def test_foldr_memory_consumption(return_updates): ...@@ -147,7 +147,7 @@ def test_foldr_memory_consumption(return_updates):
else: else:
o = o_raw o = o_raw
mode = FAST_RUN mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace") mode = mode.excluding("inplace")
f1 = function([], o, mode=mode) f1 = function([], o, mode=mode)
inputs, outputs = clone_optimized_graph(f1) inputs, outputs = clone_optimized_graph(f1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论