提交 19f1486b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow OpPattern in tracks

Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
上级 085f2723
......@@ -583,7 +583,7 @@
" def tracks(self):\n",
" return [pt.log]\n",
" \n",
" def transform(self, fgraph, node):\n",
" def transform(self, fgraph, node, enforce_tracks=True):\n",
" return local_log1p(node) \n",
" \n",
" def __str__(self):\n",
......@@ -669,8 +669,8 @@
"@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n",
" # Because of the tracks we don't need to check \n",
" # that `node` has a `Sign` Op.\n",
" # We still need to check whether it's input is an `Abs` Op\n",
" # that `node` has a `Abs` Op.\n",
" # We still need to check whether it's input is an `Exp` Op\n",
" exp_node = node.inputs[0].owner\n",
" if exp_node is None or exp_node.op != pt.exp:\n",
" return None\n",
......
......@@ -141,7 +141,12 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
self,
fgraph: FunctionGraph,
node: Apply,
enforce_tracks: bool = True,
*args,
**kwargs,
) -> TransformOutputType:
r"""Rewrite the sub-graph given by `node`.
......@@ -159,7 +164,9 @@ class NodeRewriter(Rewriter):
A `FunctionGraph` containing `node`.
node
An `Apply` node to be rewritten.
enforce_tracks: bool
Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
See `node_rewriter` tracks argument for more details.
"""
raise NotImplementedError()
......@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter):
def __init__(self, fn, tracks=None, requirements=()):
self.fn = fn
self._tracks = tracks
self._tracked_types = (
tuple(t for t in tracks if isinstance(t, type)) if tracks else ()
)
self._tracked_ops = set()
self._tracked_types = type(None)
self._tracked_op_pattern_types = type(None)
self._tracked_op_patterns: list[OpPattern] = []
if tracks is not None:
if not tracks:
raise ValueError(
"To specify a general rewrite leave tracks as None instead of an empty container"
)
for t in tracks:
if isinstance(t, Op):
self._tracked_ops.add(t)
elif isinstance(t, type):
self._tracked_types |= t
elif isinstance(t, OpPattern):
if t.parameters:
self._tracked_op_patterns.append(t)
self._tracked_op_pattern_types |= t.op_type
else:
# An OpPattern without parameters behaves like a regular tracked_type
self._tracked_types |= t
else:
raise TypeError(
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
f"Got {t} of type {type(t)}"
)
self.requirements = requirements
def transform(self, fgraph, node):
if self._tracks:
def transform(self, fgraph, node, enforce_tracks: bool = True):
if enforce_tracks and self._tracks:
node_op = node.op
if not (
node.op in self._tracks or isinstance(node.op, self._tracked_types)
node_op in self._tracked_ops
or isinstance(node_op, self._tracked_types)
or (
isinstance(node.op, self._tracked_op_pattern_types)
and any(
t.match_parameters(node_op)
for t in self._tracked_op_patterns
if isinstance(node_op, t.op_type)
)
)
):
return False
......@@ -967,7 +1007,7 @@ class FromFunctionNodeRewriter(NodeRewriter):
def node_rewriter(
tracks: Sequence[Op | type] | None,
tracks: Sequence[Op | type, OpPattern] | None,
inplace: bool = False,
requirements: tuple[type, ...] | None = (),
):
......@@ -976,7 +1016,7 @@ def node_rewriter(
Parameters
----------
tracks
The `Op` types or instances to which this rewrite applies.
The `Op` type, instances or `OpPattern` to which this rewrite applies.
Use ``None`` instead of an empty list to have the rewrite apply to
all `Op`\s.
inplace
......@@ -995,14 +1035,16 @@ def node_rewriter(
if tracks is not None:
if len(tracks) == 0:
raise ValueError(
"Use `None` instead of an empty list to make an rewrite apply to all nodes."
"Use `None` instead of an empty list to make a rewrite apply to all nodes."
)
for t in tracks:
if not (
isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op))
isinstance(t, Op | OpPattern)
or (isinstance(t, type) and issubclass(t, Op))
):
raise TypeError(
"`tracks` must consist of `Op` classes or instances."
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
f"Got {t} of type {type(t)}"
)
req = requirements
if inplace:
......@@ -1024,47 +1066,93 @@ class OpToRewriterTracker:
def __init__(self):
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
self.tracked_pattern_types: dict[type, dict[OpPattern, list[NodeRewriter]]] = (
defaultdict(lambda: defaultdict(list))
)
self.untracked_rewrites: list[NodeRewriter] = []
self._cached_composed_mro = None
def add_tracker(self, rw: NodeRewriter):
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
if self._cached_composed_mro is not None:
# We shouldn't actually add_trackers after the first call to get_trackers
# But just to be safe we kill the cache here
self._cached_composed_mro = None
tracks = rw.tracks()
if tracks is None:
self.untracked_rewrites.append(rw)
else:
for c in tracks:
if isinstance(c, OpPattern):
if not isinstance(c.op_type, type):
# OpPattern allows anything that you can check with isinstance(op, op_type),
# including tuples or union types. But for OpToRewriterTracker we need a single type.
raise NotImplementedError(
"OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
f"Got {c.op_type} of type {type(c.op_type)}"
)
if c.parameters:
self.tracked_pattern_types[c.op_type][c].append(rw)
else:
# An OpPattern without parameters behaves like a regular tracked_type
self.tracked_types[c.op_type].append(rw)
if isinstance(c, type):
self.tracked_types[c].append(rw)
else:
self.tracked_instances[c].append(rw)
def _find_impl(self, cls) -> list[NodeRewriter]:
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
@functools.cache
def get_trackers(self, op: Op) -> list[NodeRewriter]:
"""Get all the rewrites applicable to an `Op`."""
if self._cached_composed_mro is None:
# Cache the mro call on the Op type. We have a small subset of op_types we actually care about
# like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
tracked_types = (
self.tracked_types.keys() | self.tracked_pattern_types.keys()
)
@functools.cache
def cached_composed_mro(op_type, tracked_types=tracked_types):
return _compose_mro(op_type, tracked_types)
self._cached_composed_mro = cached_composed_mro
This based on `functools._find_impl`.
"""
mro = _compose_mro(cls, self.tracked_types.keys())
matches = []
for t in mro:
match = self.tracked_types.get(t, None)
if match:
matches.extend(match)
if self.tracked_types or self.tracked_pattern_types:
# Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
mro = self._cached_composed_mro(type(op))
for t in mro:
if (match := self.tracked_types.get(t, None)) is not None:
matches.extend(match)
if (
potential_matches := self.tracked_pattern_types.get(t, None)
) is not None:
# We still need to check if the Op parameters match the constraints
matches.extend(
[
item
for op_pattern, r_list in potential_matches.items()
if op_pattern.match_parameters(op)
for item in r_list
]
)
matches.extend(self.tracked_instances.get(op, []))
matches.extend(self.untracked_rewrites)
return matches
@functools.lru_cache
def get_trackers(self, op: Op) -> list[NodeRewriter]:
"""Get all the rewrites applicable to `op`."""
return (
self._find_impl(type(op))
+ self.tracked_instances.get(op, [])
+ self.untracked_rewrites
)
def get_rewriters(self):
def get_rewriters(self) -> Iterable[NodeRewriter]:
"""Get all the registered rewriters."""
return chain(
chain.from_iterable(self.tracked_types.values()),
chain.from_iterable(self.tracked_instances.values()),
chain.from_iterable(
chain(self.tracked_types.values(), self.tracked_instances.values())
item
for sub_dict in self.tracked_pattern_types.values()
for item in sub_dict.values()
),
self.untracked_rewrites,
)
......@@ -1138,7 +1226,7 @@ class SequentialNodeRewriter(NodeRewriter):
t.extend(at)
return t
def transform(self, fgraph, node):
def transform(self, fgraph, node, enforce_tracks=False):
if len(self.rewrites) == 0:
return
......@@ -1150,7 +1238,8 @@ class SequentialNodeRewriter(NodeRewriter):
new_repl = None
for rewrite in rewrites:
rewrite_start = time.perf_counter()
new_repl = rewrite.transform(fgraph, node)
# Tracks are already enforced by `self.tracker.get_trackers`
new_repl = rewrite.transform(fgraph, node, enforce_tracks=False)
rewrite_finish = time.perf_counter()
if self.profile:
self.time_rewrites[rewrite] += rewrite_start - rewrite_finish
......@@ -1292,8 +1381,8 @@ class SubstitutionNodeRewriter(NodeRewriter):
def tracks(self):
return [self.op1]
def transform(self, fgraph, node):
if node.op != self.op1:
def transform(self, fgraph, node, enforce_tracks=True):
if enforce_tracks and (node.op != self.op1):
return False
repl = self.op2.make_node(*node.inputs)
if self.transfer_tags:
......@@ -1498,7 +1587,7 @@ class PatternNodeRewriter(NodeRewriter):
def tracks(self):
return self._tracks
def transform(self, fgraph, node, get_nodes=True):
def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True):
"""Check if the graph from node corresponds to ``in_pattern``.
If it does, it constructs ``out_pattern`` and performs the replacement.
......@@ -1788,6 +1877,7 @@ class NodeProcessingGraphRewriter(GraphRewriter):
fgraph: FunctionGraph,
node: Apply,
node_rewriter: NodeRewriter | None = None,
enforce_tracks: bool = True,
):
r"""Apply `node_rewriter` to `node`.
......@@ -1805,6 +1895,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
node_rewriter
A `NodeRewriter` instance that may have a better idea for
how to compute node's outputs.
enforce_tracks: bool
Whether the transform method should enforce tracks,
or it can be assumed the caller already enforced them in a pre-filter stage.
Returns
-------
......@@ -1820,7 +1913,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
# TODO FIXME: This class's interface is broken
assert node_rewriter is not None
try:
replacements = node_rewriter.transform(fgraph, node)
replacements = node_rewriter.transform(
fgraph, node, enforce_tracks=enforce_tracks
)
except Exception as e:
if self.failure_callback is not None:
self.failure_callback(
......@@ -1938,7 +2033,8 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
if node not in fgraph.apply_nodes:
continue
current_node = node
nb += self.process_node(fgraph, node)
# This rewriter does not enforce tracks itself
nb += self.process_node(fgraph, node, enforce_tracks=True)
loop_t = time.perf_counter() - t0
finally:
self.detach_updater(fgraph, u)
......@@ -2279,8 +2375,9 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for node_rewriter in self.node_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported
t_rewrite = time.perf_counter()
# Tracks are already enfoced by `self.node_tracker.get_trackers`
node_rewriter_change = self.process_node(
fgraph, node, node_rewriter
fgraph, node, node_rewriter, enforce_tracks=False
)
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
if not node_rewriter_change:
......
......@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter):
self.node_filter = node_filter
super().__init__()
def transform(self, fgraph, node):
def transform(self, fgraph, node, enforce_tracks: bool = True):
if self.node_filter(node) is False:
return False
......@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter):
if isinstance(chosen_res, list):
new_outputs = [eval_if_etuple(v) for v in chosen_res]
else:
new_outputs = [eval_if_etuple(chosen_res)]
new_outputs = [eval_if_etuple(chosen_res)] # type: ignore[unreachable]
return new_outputs
else:
......
......@@ -278,6 +278,42 @@ class OpPattern:
Examples
--------
OpPattern can be used in the `tracks` functionality of `node_rewriter` to more flexible filter out nodes.
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
.. test-code::
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.elemwise import CAReduce
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
def local_car_reduce_all_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, CAReduce) and node.op.axis is None
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
def local_blockwise_solve_triangular_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
def local_blockwise_vector_solve_rewriter(fgraph, node):
# This will always be true!
assert (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Solve)
and node.op.core_op.b_ndim == 1
)
...
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
The example below matches two nested CAReduce Ops with the same `scalar_op`,
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
......
......@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter):
return ct + num, denum
def transform(self, fgraph, node):
def transform(self, fgraph, node, enforce_tracks=True):
op = node.op
if op not in [self.main, self.inverse, self.reciprocal]:
if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
return False
assert len(node.outputs) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论