提交 801845cc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Skip scan rewrites if there is no Scan Op in the graph

上级 051b32da
...@@ -310,18 +310,21 @@ class EquilibriumDB(RewriteDatabase): ...@@ -310,18 +310,21 @@ class EquilibriumDB(RewriteDatabase):
""" """
def __init__( def __init__(
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False self,
ignore_newtrees: bool = True,
tracks_on_change_inputs: bool = False,
eq_rewriter_class=pytensor_rewriting.EquilibriumGraphRewriter,
): ):
""" """
Parameters Parameters
---------- ----------
ignore_newtrees ignore_newtrees
If ``False``, apply rewrites to new nodes introduced during If ``False``, apply rewrites to new nodes introduced during rewritings.
rewriting.
tracks_on_change_inputs tracks_on_change_inputs
If ``True``, re-apply rewrites on nodes with changed inputs. If ``True``, re-apply rewrites on nodes with changed inputs.
eq_rewriter_class: EquilibriumGraphRewriter class, optional
The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter.
""" """
super().__init__() super().__init__()
...@@ -329,6 +332,7 @@ class EquilibriumDB(RewriteDatabase): ...@@ -329,6 +332,7 @@ class EquilibriumDB(RewriteDatabase):
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__: dict[str, bool] = {} self.__final__: dict[str, bool] = {}
self.__cleanup__: dict[str, bool] = {} self.__cleanup__: dict[str, bool] = {}
self.eq_rewriter_class = eq_rewriter_class
def register( def register(
self, self,
...@@ -360,7 +364,7 @@ class EquilibriumDB(RewriteDatabase): ...@@ -360,7 +364,7 @@ class EquilibriumDB(RewriteDatabase):
final_rewriters = None final_rewriters = None
if len(cleanup_rewriters) == 0: if len(cleanup_rewriters) == 0:
cleanup_rewriters = None cleanup_rewriters = None
return pytensor_rewriting.EquilibriumGraphRewriter( return self.eq_rewriter_class(
rewriters, rewriters,
max_use_ratio=config.optdb__max_use_ratio, max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
......
...@@ -30,6 +30,7 @@ from pytensor.graph.fg import FunctionGraph, Output ...@@ -30,6 +30,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -2517,12 +2518,21 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2517,12 +2518,21 @@ def scan_push_out_dot1(fgraph, node):
return False return False
class ScanEquilibriumGraphRewriter(EquilibriumGraphRewriter):
"""Subclass of EquilibriumGraphRewriter that aborts early if there are no Scan Ops in the graph"""
def apply(self, fgraph, start_from=None):
if not any(isinstance(node.op, Scan) for node in fgraph.apply_nodes):
return
super().apply(fgraph=fgraph, start_from=start_from)
# I've added an equilibrium because later scan optimization in the sequence # I've added an equilibrium because later scan optimization in the sequence
# can make it such that earlier optimizations should apply. However, in # can make it such that earlier optimizations should apply. However, in
# general I do not expect the sequence to run more then once # general I do not expect the sequence to run more then once
scan_eqopt1 = EquilibriumDB() scan_eqopt1 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)
scan_seqopt1 = SequenceDB() scan_seqopt1 = SequenceDB()
scan_eqopt2 = EquilibriumDB() scan_eqopt2 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)
# scan_eqopt1 before ShapeOpt at 0.1 # scan_eqopt1 before ShapeOpt at 0.1
# This is needed to don't have ShapeFeature trac old Scan that we # This is needed to don't have ShapeFeature trac old Scan that we
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论