提交 19f1486b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow OpPattern in tracks

Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
上级 085f2723
......@@ -583,7 +583,7 @@
" def tracks(self):\n",
" return [pt.log]\n",
" \n",
" def transform(self, fgraph, node):\n",
" def transform(self, fgraph, node, enforce_tracks=True):\n",
" return local_log1p(node) \n",
" \n",
" def __str__(self):\n",
......@@ -669,8 +669,8 @@
"@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n",
" # Because of the tracks we don't need to check \n",
" # that `node` has a `Sign` Op.\n",
" # We still need to check whether it's input is an `Abs` Op\n",
" # that `node` has a `Abs` Op.\n",
" # We still need to check whether it's input is an `Exp` Op\n",
" exp_node = node.inputs[0].owner\n",
" if exp_node is None or exp_node.op != pt.exp:\n",
" return None\n",
......
......@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter):
self.node_filter = node_filter
super().__init__()
def transform(self, fgraph, node):
def transform(self, fgraph, node, enforce_tracks: bool = True):
if self.node_filter(node) is False:
return False
......@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter):
if isinstance(chosen_res, list):
new_outputs = [eval_if_etuple(v) for v in chosen_res]
else:
new_outputs = [eval_if_etuple(chosen_res)]
new_outputs = [eval_if_etuple(chosen_res)] # type: ignore[unreachable]
return new_outputs
else:
......
......@@ -278,6 +278,42 @@ class OpPattern:
Examples
--------
OpPattern can be used in the `tracks` functionality of `node_rewriter` to more flexible filter out nodes.
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
.. test-code::
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.elemwise import CAReduce
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
def local_car_reduce_all_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, CAReduce) and node.op.axis is None
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
def local_blockwise_solve_triangular_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
def local_blockwise_vector_solve_rewriter(fgraph, node):
# This will always be true!
assert (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Solve)
and node.op.core_op.b_ndim == 1
)
...
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
The example below matches two nested CAReduce Ops with the same `scalar_op`,
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
......
......@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter):
return ct + num, denum
def transform(self, fgraph, node):
def transform(self, fgraph, node, enforce_tracks=True):
op = node.op
if op not in [self.main, self.inverse, self.reciprocal]:
if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
return False
assert len(node.outputs) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论