提交 e274e0d2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow linkers to specify required/incompatible rewrites

上级 f5bf2afa
...@@ -1331,7 +1331,11 @@ default_make_thunk = [get_unbound_function(COp.make_thunk)] ...@@ -1331,7 +1331,11 @@ default_make_thunk = [get_unbound_function(COp.make_thunk)]
# the external requirements of the .linker attribute of a mode # the external requirements of the .linker attribute of a mode
# 1) it's a class instance # 1) it's a class instance
# 2) it a has a .clone() method # 2) it a has a .clone() method
# 3) it has required_rewrites and incompatible_rewrites class attributes
class _DummyLinker: class _DummyLinker:
required_rewrites = ()
incompatible_rewrites = ()
# This is not a real linker anyway # This is not a real linker anyway
def clone(self, allow_gc=None): def clone(self, allow_gc=None):
return self return self
......
...@@ -352,7 +352,14 @@ class Mode: ...@@ -352,7 +352,14 @@ class Mode:
if isinstance(optimizer, str) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, RewriteDatabaseQuery): if isinstance(optimizer, RewriteDatabaseQuery):
# TODO: From the __init__ signature this should always be the case
# But some tests and internal logic allow passing a GraphRewriter directly as optimizer
# Cleanup!
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
if r := linker.required_rewrites:
optimizer = optimizer.including(*r)
if r := linker.incompatible_rewrites:
optimizer = optimizer.excluding(*r)
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
self.fn_time = 0 self.fn_time = 0
...@@ -365,14 +372,13 @@ class Mode: ...@@ -365,14 +372,13 @@ class Mode:
f"optdb={self.optdb})" f"optdb={self.optdb})"
) )
def __get_optimizer(self): @property
def optimizer(self):
if isinstance(self._optimizer, RewriteDatabaseQuery): if isinstance(self._optimizer, RewriteDatabaseQuery):
return self.optdb.query(self._optimizer) return self.optdb.query(self._optimizer)
else: else:
return self._optimizer return self._optimizer
optimizer = property(__get_optimizer)
def get_linker_optimizer(self, linker, optimizer): def get_linker_optimizer(self, linker, optimizer):
if isinstance(linker, str) or linker is None: if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker] linker = predefined_linkers[linker]
...@@ -466,61 +472,21 @@ C_VM = Mode("cvm", "fast_run") ...@@ -466,61 +472,21 @@ C_VM = Mode("cvm", "fast_run")
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(include=["fast_run", "numba"]),
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
) )
JAX = Mode( JAX = Mode(
JAXLinker(), JAXLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(include=["fast_run", "jax"]),
include=["fast_run", "jax"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
) )
PYTORCH = Mode( PYTORCH = Mode(
PytorchLinker(), PytorchLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(include=["fast_run"]),
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
) )
MLX = Mode( MLX = Mode(
MLXLinker(), MLXLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(include=["fast_run"]),
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
) )
......
...@@ -157,6 +157,9 @@ class Linker(ABC): ...@@ -157,6 +157,9 @@ class Linker(ABC):
the FunctionGraph. the FunctionGraph.
""" """
required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()
def __init__( def __init__(
self, self,
*, *,
......
...@@ -9,6 +9,22 @@ from pytensor.link.basic import JITLinker ...@@ -9,6 +9,22 @@ from pytensor.link.basic import JITLinker
class JAXLinker(JITLinker): class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.""" """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
required_rewrites = (
"minimum_compile",
"jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
# JAX does it his own inplace optimization
"inplace",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)
scalar_shape_inputs: tuple[int, ...] scalar_shape_inputs: tuple[int, ...]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -4,6 +4,14 @@ from pytensor.link.basic import JITLinker ...@@ -4,6 +4,14 @@ from pytensor.link.basic import JITLinker
class MLXLinker(JITLinker): class MLXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
)
def __init__(self, use_compile=True, *args, **kwargs): def __init__(self, use_compile=True, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.gen_functors = [] self.gen_functors = []
......
...@@ -2,6 +2,17 @@ from pytensor.link.basic import JITLinker ...@@ -2,6 +2,17 @@ from pytensor.link.basic import JITLinker
class NumbaLinker(JITLinker): class NumbaLinker(JITLinker):
required_rewrites = (
"minimum_compile",
"numba",
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
)
"""A `Linker` that JIT-compiles NumPy-based operations using Numba.""" """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert(self, fgraph, **kwargs): def fgraph_convert(self, fgraph, **kwargs):
......
...@@ -5,6 +5,16 @@ from pytensor.link.utils import unique_name_generator ...@@ -5,6 +5,16 @@ from pytensor.link.utils import unique_name_generator
class PytorchLinker(JITLinker): class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile.""" """A `Linker` that compiles NumPy-based operations using torch.compile."""
incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.gen_functors = [] self.gen_functors = []
......
...@@ -55,11 +55,15 @@ def test_NoOutputFromInplace(): ...@@ -55,11 +55,15 @@ def test_NoOutputFromInplace():
def test_including(): def test_including():
mode = Mode(optimizer="merge") mode = Mode(linker="py", optimizer="merge")
assert set(mode._optimizer.include) == {"merge"} assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
new_mode = mode.including("fast_compile") new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"} assert set(new_mode._optimizer.include) == {
"minimum_compile",
"merge",
"fast_compile",
}
class TestBunchOfModes: class TestBunchOfModes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论