提交 19f1486b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow OpPattern in tracks

Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
上级 085f2723
...@@ -583,7 +583,7 @@ ...@@ -583,7 +583,7 @@
" def tracks(self):\n", " def tracks(self):\n",
" return [pt.log]\n", " return [pt.log]\n",
" \n", " \n",
" def transform(self, fgraph, node):\n", " def transform(self, fgraph, node, enforce_tracks=True):\n",
" return local_log1p(node) \n", " return local_log1p(node) \n",
" \n", " \n",
" def __str__(self):\n", " def __str__(self):\n",
...@@ -669,8 +669,8 @@ ...@@ -669,8 +669,8 @@
"@node_rewriter(tracks=[pt.abs])\n", "@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n", "def local_useless_abs_exp(fgraph, node):\n",
" # Because of the tracks we don't need to check \n", " # Because of the tracks we don't need to check \n",
" # that `node` has a `Sign` Op.\n", " # that `node` has a `Abs` Op.\n",
" # We still need to check whether it's input is an `Abs` Op\n", " # We still need to check whether it's input is an `Exp` Op\n",
" exp_node = node.inputs[0].owner\n", " exp_node = node.inputs[0].owner\n",
" if exp_node is None or exp_node.op != pt.exp:\n", " if exp_node is None or exp_node.op != pt.exp:\n",
" return None\n", " return None\n",
......
...@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -74,7 +74,7 @@ class KanrenRelationSub(NodeRewriter):
self.node_filter = node_filter self.node_filter = node_filter
super().__init__() super().__init__()
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks: bool = True):
if self.node_filter(node) is False: if self.node_filter(node) is False:
return False return False
...@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -92,7 +92,7 @@ class KanrenRelationSub(NodeRewriter):
if isinstance(chosen_res, list): if isinstance(chosen_res, list):
new_outputs = [eval_if_etuple(v) for v in chosen_res] new_outputs = [eval_if_etuple(v) for v in chosen_res]
else: else:
new_outputs = [eval_if_etuple(chosen_res)] new_outputs = [eval_if_etuple(chosen_res)] # type: ignore[unreachable]
return new_outputs return new_outputs
else: else:
......
...@@ -278,6 +278,42 @@ class OpPattern: ...@@ -278,6 +278,42 @@ class OpPattern:
Examples Examples
-------- --------
OpPattern can be used in the `tracks` functionality of `node_rewriter` to more flexible filter out nodes.
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
.. test-code::
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.elemwise import CAReduce
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
def local_car_reduce_all_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, CAReduce) and node.op.axis is None
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
def local_blockwise_solve_triangular_rewriter(fgraph, node):
# This will always be true!
assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
...
# Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
def local_blockwise_vector_solve_rewriter(fgraph, node):
# This will always be true!
assert (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Solve)
and node.op.core_op.b_ndim == 1
)
...
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters. OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
The example below matches two nested CAReduce Ops with the same `scalar_op`, The example below matches two nested CAReduce Ops with the same `scalar_op`,
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce. the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
......
...@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1338,9 +1338,9 @@ class AlgebraicCanonizer(NodeRewriter):
return ct + num, denum return ct + num, denum
def transform(self, fgraph, node): def transform(self, fgraph, node, enforce_tracks=True):
op = node.op op = node.op
if op not in [self.main, self.inverse, self.reciprocal]: if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
return False return False
assert len(node.outputs) == 1 assert len(node.outputs) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论