提交 e274e0d2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow linkers to specify required/incompatible rewrites

上级 f5bf2afa
......@@ -1331,7 +1331,11 @@ default_make_thunk = [get_unbound_function(COp.make_thunk)]
# the external requirements of the .linker attribute of a mode
# 1) it's a class instance
# 2) it a has a .clone() method
# 3) it has required_rewrites and incompatible_rewrites class attributes
class _DummyLinker:
required_rewrites = ()
incompatible_rewrites = ()
# This is not a real linker anyway
def clone(self, allow_gc=None):
return self
......
......@@ -352,7 +352,14 @@ class Mode:
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, RewriteDatabaseQuery):
# TODO: From the __init__ signature this should always be the case
# But some tests and internal logic allow passing a GraphRewriter directly as optimizer
# Cleanup!
self.provided_optimizer = optimizer
if r := linker.required_rewrites:
optimizer = optimizer.including(*r)
if r := linker.incompatible_rewrites:
optimizer = optimizer.excluding(*r)
self._optimizer = optimizer
self.call_time = 0
self.fn_time = 0
......@@ -365,14 +372,13 @@ class Mode:
f"optdb={self.optdb})"
)
def __get_optimizer(self):
@property
def optimizer(self):
if isinstance(self._optimizer, RewriteDatabaseQuery):
return self.optdb.query(self._optimizer)
else:
return self._optimizer
optimizer = property(__get_optimizer)
def get_linker_optimizer(self, linker, optimizer):
if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker]
......@@ -466,61 +472,21 @@ C_VM = Mode("cvm", "fast_run")
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
RewriteDatabaseQuery(include=["fast_run", "numba"]),
)
JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(
include=["fast_run", "jax"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
RewriteDatabaseQuery(include=["fast_run", "jax"]),
)
PYTORCH = Mode(
PytorchLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
RewriteDatabaseQuery(include=["fast_run"]),
)
MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
RewriteDatabaseQuery(include=["fast_run"]),
)
......
......@@ -157,6 +157,9 @@ class Linker(ABC):
the FunctionGraph.
"""
required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()
def __init__(
self,
*,
......
......@@ -9,6 +9,22 @@ from pytensor.link.basic import JITLinker
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
required_rewrites = (
"minimum_compile",
"jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
# JAX does it his own inplace optimization
"inplace",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)
scalar_shape_inputs: tuple[int, ...]
def __init__(self, *args, **kwargs):
......
......@@ -4,6 +4,14 @@ from pytensor.link.basic import JITLinker
class MLXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
)
def __init__(self, use_compile=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []
......
......@@ -2,6 +2,17 @@ from pytensor.link.basic import JITLinker
class NumbaLinker(JITLinker):
required_rewrites = (
"minimum_compile",
"numba",
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
)
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert(self, fgraph, **kwargs):
......
......@@ -5,6 +5,16 @@ from pytensor.link.utils import unique_name_generator
class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []
......
......@@ -55,11 +55,15 @@ def test_NoOutputFromInplace():
def test_including():
mode = Mode(optimizer="merge")
assert set(mode._optimizer.include) == {"merge"}
mode = Mode(linker="py", optimizer="merge")
assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
assert set(new_mode._optimizer.include) == {
"minimum_compile",
"merge",
"fast_compile",
}
class TestBunchOfModes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论