提交 19f1486b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow OpPattern in tracks

Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
上级 085f2723
...@@ -583,7 +583,7 @@ ...@@ -583,7 +583,7 @@
" def tracks(self):\n", " def tracks(self):\n",
" return [pt.log]\n", " return [pt.log]\n",
" \n", " \n",
" def transform(self, fgraph, node):\n", " def transform(self, fgraph, node, enforce_tracks=True):\n",
" return local_log1p(node) \n", " return local_log1p(node) \n",
" \n", " \n",
" def __str__(self):\n", " def __str__(self):\n",
...@@ -669,8 +669,8 @@ ...@@ -669,8 +669,8 @@
"@node_rewriter(tracks=[pt.abs])\n", "@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n", "def local_useless_abs_exp(fgraph, node):\n",
" # Because of the tracks we don't need to check \n", " # Because of the tracks we don't need to check \n",
" # that `node` has a `Sign` Op.\n", " # that `node` has a `Abs` Op.\n",
" # We still need to check whether it's input is an `Abs` Op\n", " # We still need to check whether it's input is an `Exp` Op\n",
" exp_node = node.inputs[0].owner\n", " exp_node = node.inputs[0].owner\n",
" if exp_node is None or exp_node.op != pt.exp:\n", " if exp_node is None or exp_node.op != pt.exp:\n",
" return None\n", " return None\n",
......
...@@ -141,7 +141,12 @@ class NodeRewriter(Rewriter): ...@@ -141,7 +141,12 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod @abc.abstractmethod
def transform( def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs self,
fgraph: FunctionGraph,
node: Apply,
enforce_tracks: bool = True,
*args,
**kwargs,
) -> TransformOutputType: ) -> TransformOutputType:
r"""Rewrite the sub-graph given by `node`. r"""Rewrite the sub-graph given by `node`.
...@@ -159,7 +164,9 @@ class NodeRewriter(Rewriter): ...@@ -159,7 +164,9 @@ class NodeRewriter(Rewriter):
A `FunctionGraph` containing `node`. A `FunctionGraph` containing `node`.
node node
An `Apply` node to be rewritten. An `Apply` node to be rewritten.
enforce_tracks: bool
Whether the transform method should enforce tracks itself, or whether it can assume the caller already enforced them in a pre-filter stage.
See `node_rewriter` tracks argument for more details.
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter): ...@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter):
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
self.fn = fn self.fn = fn
self._tracks = tracks self._tracks = tracks
self._tracked_types = ( self._tracked_ops = set()
tuple(t for t in tracks if isinstance(t, type)) if tracks else () self._tracked_types = type(None)
) self._tracked_op_pattern_types = type(None)
self._tracked_op_patterns: list[OpPattern] = []
if tracks is not None:
if not tracks:
raise ValueError(
"To specify a general rewrite leave tracks as None instead of an empty container"
)
for t in tracks:
if isinstance(t, Op):
self._tracked_ops.add(t)
elif isinstance(t, type):
self._tracked_types |= t
elif isinstance(t, OpPattern):
if t.parameters:
self._tracked_op_patterns.append(t)
self._tracked_op_pattern_types |= t.op_type
else:
# An OpPattern without parameters behaves like a regular tracked_type
self._tracked_types |= t
else:
raise TypeError(
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
f"Got {t} of type {type(t)}"
)
self.requirements = requirements self.requirements = requirements
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks: bool = True):
if self._tracks: if enforce_tracks and self._tracks:
node_op = node.op
if not ( if not (
node.op in self._tracks or isinstance(node.op, self._tracked_types) node_op in self._tracked_ops
or isinstance(node_op, self._tracked_types)
or (
isinstance(node.op, self._tracked_op_pattern_types)
and any(
t.match_parameters(node_op)
for t in self._tracked_op_patterns
if isinstance(node_op, t.op_type)
)
)
): ):
return False return False
...@@ -967,7 +1007,7 @@ class FromFunctionNodeRewriter(NodeRewriter): ...@@ -967,7 +1007,7 @@ class FromFunctionNodeRewriter(NodeRewriter):
def node_rewriter( def node_rewriter(
tracks: Sequence[Op | type] | None, tracks: Sequence[Op | type, OpPattern] | None,
inplace: bool = False, inplace: bool = False,
requirements: tuple[type, ...] | None = (), requirements: tuple[type, ...] | None = (),
): ):
...@@ -976,7 +1016,7 @@ def node_rewriter( ...@@ -976,7 +1016,7 @@ def node_rewriter(
Parameters Parameters
---------- ----------
tracks tracks
The `Op` types or instances to which this rewrite applies. The `Op` type, instances or `OpPattern` to which this rewrite applies.
Use ``None`` instead of an empty list to have the rewrite apply to Use ``None`` instead of an empty list to have the rewrite apply to
all `Op`\s. all `Op`\s.
inplace inplace
...@@ -995,14 +1035,16 @@ def node_rewriter( ...@@ -995,14 +1035,16 @@ def node_rewriter(
if tracks is not None: if tracks is not None:
if len(tracks) == 0: if len(tracks) == 0:
raise ValueError( raise ValueError(
"Use `None` instead of an empty list to make an rewrite apply to all nodes." "Use `None` instead of an empty list to make a rewrite apply to all nodes."
) )
for t in tracks: for t in tracks:
if not ( if not (
isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op)) isinstance(t, Op | OpPattern)
or (isinstance(t, type) and issubclass(t, Op))
): ):
raise TypeError( raise TypeError(
"`tracks` must consist of `Op` classes or instances." "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
f"Got {t} of type {type(t)}"
) )
req = requirements req = requirements
if inplace: if inplace:
...@@ -1024,47 +1066,93 @@ class OpToRewriterTracker: ...@@ -1024,47 +1066,93 @@ class OpToRewriterTracker:
def __init__(self): def __init__(self):
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list) self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list) self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
self.tracked_pattern_types: dict[type, dict[OpPattern, list[NodeRewriter]]] = (
defaultdict(lambda: defaultdict(list))
)
self.untracked_rewrites: list[NodeRewriter] = [] self.untracked_rewrites: list[NodeRewriter] = []
self._cached_composed_mro = None
def add_tracker(self, rw: NodeRewriter): def add_tracker(self, rw: NodeRewriter):
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally.""" """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
if self._cached_composed_mro is not None:
# We shouldn't actually add_trackers after the first call to get_trackers
# But just to be safe we kill the cache here
self._cached_composed_mro = None
tracks = rw.tracks() tracks = rw.tracks()
if tracks is None: if tracks is None:
self.untracked_rewrites.append(rw) self.untracked_rewrites.append(rw)
else: else:
for c in tracks: for c in tracks:
if isinstance(c, OpPattern):
if not isinstance(c.op_type, type):
# OpPattern allows anything that you can check with isinstance(op, op_type),
# including tuples or union types. But for OpToRewriterTracker we need a single type.
raise NotImplementedError(
"OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
f"Got {c.op_type} of type {type(c.op_type)}"
)
if c.parameters:
self.tracked_pattern_types[c.op_type][c].append(rw)
else:
# An OpPattern without parameters behaves like a regular tracked_type
self.tracked_types[c.op_type].append(rw)
if isinstance(c, type): if isinstance(c, type):
self.tracked_types[c].append(rw) self.tracked_types[c].append(rw)
else: else:
self.tracked_instances[c].append(rw) self.tracked_instances[c].append(rw)
def _find_impl(self, cls) -> list[NodeRewriter]: @functools.cache
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance. def get_trackers(self, op: Op) -> list[NodeRewriter]:
"""Get all the rewrites applicable to an `Op`."""
if self._cached_composed_mro is None:
# Cache the mro call on the Op type. We have a small subset of op_types we actually care about
# like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
tracked_types = (
self.tracked_types.keys() | self.tracked_pattern_types.keys()
)
@functools.cache
def cached_composed_mro(op_type, tracked_types=tracked_types):
return _compose_mro(op_type, tracked_types)
self._cached_composed_mro = cached_composed_mro
This based on `functools._find_impl`.
"""
mro = _compose_mro(cls, self.tracked_types.keys())
matches = [] matches = []
for t in mro: if self.tracked_types or self.tracked_pattern_types:
match = self.tracked_types.get(t, None) # Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
if match: mro = self._cached_composed_mro(type(op))
matches.extend(match) for t in mro:
if (match := self.tracked_types.get(t, None)) is not None:
matches.extend(match)
if (
potential_matches := self.tracked_pattern_types.get(t, None)
) is not None:
# We still need to check if the Op parameters match the constraints
matches.extend(
[
item
for op_pattern, r_list in potential_matches.items()
if op_pattern.match_parameters(op)
for item in r_list
]
)
matches.extend(self.tracked_instances.get(op, []))
matches.extend(self.untracked_rewrites)
return matches return matches
@functools.lru_cache def get_rewriters(self) -> Iterable[NodeRewriter]:
def get_trackers(self, op: Op) -> list[NodeRewriter]: """Get all the registered rewriters."""
"""Get all the rewrites applicable to `op`."""
return (
self._find_impl(type(op))
+ self.tracked_instances.get(op, [])
+ self.untracked_rewrites
)
def get_rewriters(self):
return chain( return chain(
chain.from_iterable(self.tracked_types.values()),
chain.from_iterable(self.tracked_instances.values()),
chain.from_iterable( chain.from_iterable(
chain(self.tracked_types.values(), self.tracked_instances.values()) item
for sub_dict in self.tracked_pattern_types.values()
for item in sub_dict.values()
), ),
self.untracked_rewrites, self.untracked_rewrites,
) )
...@@ -1138,7 +1226,7 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1138,7 +1226,7 @@ class SequentialNodeRewriter(NodeRewriter):
t.extend(at) t.extend(at)
return t return t
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks=False):
if len(self.rewrites) == 0: if len(self.rewrites) == 0:
return return
...@@ -1150,7 +1238,8 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1150,7 +1238,8 @@ class SequentialNodeRewriter(NodeRewriter):
new_repl = None new_repl = None
for rewrite in rewrites: for rewrite in rewrites:
rewrite_start = time.perf_counter() rewrite_start = time.perf_counter()
new_repl = rewrite.transform(fgraph, node) # Tracks are already enforced by `self.tracker.get_trackers`
new_repl = rewrite.transform(fgraph, node, enforce_tracks=False)
rewrite_finish = time.perf_counter() rewrite_finish = time.perf_counter()
if self.profile: if self.profile:
self.time_rewrites[rewrite] += rewrite_start - rewrite_finish self.time_rewrites[rewrite] += rewrite_start - rewrite_finish
...@@ -1292,8 +1381,8 @@ class SubstitutionNodeRewriter(NodeRewriter): ...@@ -1292,8 +1381,8 @@ class SubstitutionNodeRewriter(NodeRewriter):
def tracks(self): def tracks(self):
return [self.op1] return [self.op1]
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks=True):
if node.op != self.op1: if enforce_tracks and (node.op != self.op1):
return False return False
repl = self.op2.make_node(*node.inputs) repl = self.op2.make_node(*node.inputs)
if self.transfer_tags: if self.transfer_tags:
...@@ -1498,7 +1587,7 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1498,7 +1587,7 @@ class PatternNodeRewriter(NodeRewriter):
def tracks(self): def tracks(self):
return self._tracks return self._tracks
def transform(self, fgraph, node, get_nodes=True): def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True):
"""Check if the graph from node corresponds to ``in_pattern``. """Check if the graph from node corresponds to ``in_pattern``.
If it does, it constructs ``out_pattern`` and performs the replacement. If it does, it constructs ``out_pattern`` and performs the replacement.
...@@ -1788,6 +1877,7 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1788,6 +1877,7 @@ class NodeProcessingGraphRewriter(GraphRewriter):
fgraph: FunctionGraph, fgraph: FunctionGraph,
node: Apply, node: Apply,
node_rewriter: NodeRewriter | None = None, node_rewriter: NodeRewriter | None = None,
enforce_tracks: bool = True,
): ):
r"""Apply `node_rewriter` to `node`. r"""Apply `node_rewriter` to `node`.
...@@ -1805,6 +1895,9 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1805,6 +1895,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
node_rewriter node_rewriter
A `NodeRewriter` instance that may have a better idea for A `NodeRewriter` instance that may have a better idea for
how to compute node's outputs. how to compute node's outputs.
enforce_tracks: bool
Whether the transform method should enforce tracks,
or it can be assumed the caller already enforced them in a pre-filter stage.
Returns Returns
------- -------
...@@ -1820,7 +1913,9 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1820,7 +1913,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
# TODO FIXME: This class's interface is broken # TODO FIXME: This class's interface is broken
assert node_rewriter is not None assert node_rewriter is not None
try: try:
replacements = node_rewriter.transform(fgraph, node) replacements = node_rewriter.transform(
fgraph, node, enforce_tracks=enforce_tracks
)
except Exception as e: except Exception as e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback( self.failure_callback(
...@@ -1938,7 +2033,8 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): ...@@ -1938,7 +2033,8 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
continue continue
current_node = node current_node = node
nb += self.process_node(fgraph, node) # This rewriter does not enforce tracks itself
nb += self.process_node(fgraph, node, enforce_tracks=True)
loop_t = time.perf_counter() - t0 loop_t = time.perf_counter() - t0
finally: finally:
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
...@@ -2279,8 +2375,9 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2279,8 +2375,9 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for node_rewriter in self.node_tracker.get_trackers(node.op): for node_rewriter in self.node_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported nb = change_tracker.nb_imported
t_rewrite = time.perf_counter() t_rewrite = time.perf_counter()
# Tracks are already enforced by `self.node_tracker.get_trackers`
node_rewriter_change = self.process_node( node_rewriter_change = self.process_node(
fgraph, node, node_rewriter fgraph, node, node_rewriter, enforce_tracks=False
) )
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
if not node_rewriter_change: if not node_rewriter_change:
......
...@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter):
self.node_filter = node_filter self.node_filter = node_filter
super().__init__() super().__init__()
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks: bool = True):
if self.node_filter(node) is False: if self.node_filter(node) is False:
return False return False
...@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter):
if isinstance(chosen_res, list): if isinstance(chosen_res, list):
new_outputs = [eval_if_etuple(v) for v in chosen_res] new_outputs = [eval_if_etuple(v) for v in chosen_res]
else: else:
new_outputs = [eval_if_etuple(chosen_res)] new_outputs = [eval_if_etuple(chosen_res)] # type: ignore[unreachable]
return new_outputs return new_outputs
else: else:
......
...@@ -278,6 +278,42 @@ class OpPattern: ...@@ -278,6 +278,42 @@ class OpPattern:
Examples Examples
-------- --------
OpPattern can be used in the `tracks` argument of `node_rewriter` to filter nodes more flexibly.
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
.. test-code::
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.elemwise import CAReduce
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
def local_car_reduce_all_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, CAReduce) and node.op.axis is None
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
def local_blockwise_solve_triangular_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
def local_blockwise_vector_solve_rewriter(fgraph, node):
# This will always be true!
assert (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Solve)
and node.op.core_op.b_ndim == 1
)
...
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters. OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
The example below matches two nested CAReduce Ops with the same `scalar_op`, The example below matches two nested CAReduce Ops with the same `scalar_op`,
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce. the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
......
...@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter):
return ct + num, denum return ct + num, denum
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks=True):
op = node.op op = node.op
if op not in [self.main, self.inverse, self.reciprocal]: if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
return False return False
assert len(node.outputs) == 1 assert len(node.outputs) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论