提交 085f2723 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Allow unifying with OpPattern

上级 51ab571a
......@@ -29,7 +29,7 @@ from pytensor.graph.basic import (
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
......@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
The input and output patterns have the following syntax:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
......@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
output_pattern ::= callable
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
......@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
Examples
--------
PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternNodeRewriter((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import add, mul, sub, pow, square
PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
PatternNodeRewriter((mul, "x", "x"), (square, "x"))
PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
PatternNodeRewriter(
(mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
(mul, "y", "x"),
)
You can use OpPattern to match a subtype of an Op, with some parameter constraints
You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.basic import Join
from pytensor.tensor.elemwise import CAReduce, Elemwise
def output_fn(fgraph, node, s):
reduce_op = node.op
reduced_a = reduce_op(s["a"])
reduced_b = reduce_op(s["b"])
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
PatternNodeRewriter(
(
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
(Join(), "join_axis", "a", "b"),
),
output_fn,
)
If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
PatternNodeRewriter(
(
OpPattern(
Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
),
"A",
"b",
)
)
"""
def __init__(
self,
in_pattern,
out_pattern,
in_pattern: tuple,
out_pattern: tuple | Callable | str,
allow_multiple_clients: bool = False,
name: str | None = None,
tracks=(),
......@@ -1378,7 +1433,8 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs,
and returns the replacement variable (or None/False to reject the rewrite).
allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client.
......@@ -1407,26 +1463,40 @@ class PatternNodeRewriter(NodeRewriter):
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
self.values_eq_approx = values_eq_approx
self.allow_cast = allow_cast
if isinstance(in_pattern, list | tuple):
self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
self.op = self.in_pattern["pattern"][0]
else:
raise TypeError(
"The pattern to search for must start with a specific Op instance."
)
self.allow_multiple_clients = allow_multiple_clients
if name:
self.__name__ = name
self._tracks = tracks
self.get_nodes = get_nodes
if tracks != ():
assert get_nodes
if not get_nodes:
raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
self._tracks = tracks
else:
if isinstance(in_pattern, list | tuple):
op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
op = self.in_pattern["pattern"][0]
else:
raise TypeError(
f"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}"
)
if isinstance(op, Op):
self._tracks = [op]
elif isinstance(op, type) and issubclass(op, Op):
raise ValueError(
f"The in_pattern starts with an Op class {op}, not an instance.\n"
"You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
)
elif isinstance(op, OpPattern):
self._tracks = [op.op_type]
else:
raise ValueError(
f"The in_pattern must start with a specific Op or an OpPattern instance. "
f"Got {op}, with type {type(op)}."
)
def tracks(self):
if self._tracks != ():
return self._tracks
return [self.op]
return self._tracks
def transform(self, fgraph, node, get_nodes=True):
"""Check if the graph from node corresponds to ``in_pattern``.
......@@ -1447,28 +1517,39 @@ class PatternNodeRewriter(NodeRewriter):
# PatternNodeRewriter doesn't support replacing multi-output nodes
return False
s = unify(self.in_pattern, node.out)
s = unify(self.in_pattern, node.out, {})
if s is False:
return False
ret = reify(self.out_pattern, s)
if isinstance(ret, ExpressionTuple):
ret = ret.evaled_obj
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
if not self.allow_multiple_clients:
input_vars = list(s.values())
input_vars = set(s.values())
clients = fgraph.clients
if any(
len(fgraph.clients[v]) > 1
len(clients[v]) > 1
for v in vars_between(input_vars, node.inputs)
if v not in input_vars
):
return False
if callable(self.out_pattern):
# token is the variable name used in the original pattern
ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
if ret is None or ret is False:
# The output function is still allowed to reject the rewrite
return False
if not isinstance(ret, Variable):
raise ValueError(
f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
)
else:
ret = reify(self.out_pattern, s)
if isinstance(ret, ExpressionTuple):
ret = ret.evaled_obj
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
from pytensor.tensor.type import TensorType
......
......@@ -86,7 +86,7 @@ class KanrenRelationSub(NodeRewriter):
q = var()
kanren_results = run(None, q, self.kanren_relation(input_expr, q))
chosen_res = self.results_filter(kanren_results)
chosen_res = self.results_filter(kanren_results) # type: ignore[arg-type]
if chosen_res:
if isinstance(chosen_res, list):
......
......@@ -10,8 +10,11 @@ that satisfies the constraints. That's useful for pattern matching.
"""
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from numbers import Number
from types import UnionType
from typing import Any, TypeAlias
import numpy as np
from cons.core import ConsError, _car, _cdr
......@@ -254,6 +257,164 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)
@dataclass(frozen=True)
class LiteralString:
    """Mark a string as a literal value rather than a unification variable.

    Plain strings appearing in a rewrite pattern are converted into
    unification variables by `convert_strs_to_vars`.  Wrapping a string in
    ``LiteralString`` makes the conversion unwrap it back to the raw string,
    so it is compared by equality during matching (e.g. an Op parameter such
    as ``assume_a="gen"``).
    """

    # The exact string the pattern must match.
    value: str
# Accepted forms for the Op type constraint: a single Op subclass, a tuple of
# subclasses, or a PEP 604 union (all valid second arguments to `isinstance`).
OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType


@dataclass(unsafe_hash=True)
class OpPattern:
    """Class that can be unified with Op instances of a given type (or instance) and parameters.

    Parameters that are not specified in the OpPattern are ignored during unification.
    This is needed because some Ops can be complex to parametrize fully,
    and not all parameters are relevant for a given pattern.

    Examples
    --------
    OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.

    The example below matches two nested CAReduce Ops with the same `scalar_op`,
    the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
    Note, that because we didn't specify it, the axis of the inner CAReduce can be anything.
    The same goes for other properties of the Op that are not specified in the OpPattern.

    .. testcode::

        from pytensor.graph.rewriting.basic import PatternNodeRewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.elemwise import CAReduce

        PatternNodeRewriter(
            in_pattern=(
                OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
                (OpPattern(CAReduce, scalar_op="scalar_op"), "x"),
            ),
            out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"),
        )

    OpPattern can also be used with `unification.unify` to match Ops with specific parameters.
    This is used by PatternNodeRewriter but can also be used directly.

    .. testcode::

        from unification import var, unify
        from etuples import etuple

        import pytensor.tensor as pt
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        A = var("A")
        b = var("b")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")),
            A,
            b,
        )

        A_pt = pt.tensor3("A")
        b_pt = pt.tensor3("b")
        out1 = pt.linalg.solve(A_pt, b_pt)
        out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")
        assert unify(pattern, out1) == {A: A_pt, b: b_pt}
        assert unify(pattern, out2) is False

        assume_a = var("assume_a")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
            A,
            b,
        )
        assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
        assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
    """

    # The Op class (or tuple/union of classes) a candidate Op must be an instance of.
    op_type: OpPatternOpTypeType
    # Normalized, sorted tuple of (parameter_name, expected_value) pairs.
    # Values may themselves be unification variables or nested OpPatterns.
    parameters: tuple[tuple[str, Any], ...]

    def __init__(
        self,
        op_type: OpPatternOpTypeType,
        parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,
        **kwargs,
    ):
        # Parameters may be given either positionally (dict/sequence of pairs)
        # or as keyword arguments, but not both at once.
        if kwargs:
            if parameters is not None:
                raise ValueError(
                    "Cannot provide both parameters dict and keyword arguments"
                )
            parameters = kwargs
        # Normalize to a sorted tuple of pairs so that hashing/equality of the
        # (unsafe_hash) dataclass does not depend on the order parameters were given.
        if isinstance(parameters, dict):
            parameters = tuple(sorted(parameters.items()))
        elif isinstance(parameters, list | tuple):
            parameters = tuple(sorted(parameters))
        elif parameters is None:
            parameters = ()
        self.op_type = op_type
        self.parameters = parameters  # type: ignore[assignment]

    def match_op(self, op: Op) -> bool:
        """Return whether ``op`` is of the right type and all specified parameters match."""
        if not isinstance(op, self.op_type):
            return False
        return self.match_parameters(op)

    def match_parameters(self, op: Op) -> bool:
        """Return whether all specified parameters of ``op`` match this pattern.

        The ``op_type`` of ``op`` is assumed to have been checked by the caller.
        """
        # This is used by methods that already check the op_type is satisfied
        # Some methods may index on the op_type and know in advance the op is matched
        # Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below)
        for key, param in self.parameters:
            if isinstance(param, OpPattern):
                # Parameters can itself be other OpPatterns
                # We check the op_type to avoid a nested call in cases we can reject early
                sub_op = getattr(op, key)
                if not isinstance(sub_op, param.op_type):
                    return False
                # Match the pattern of the inner Op
                # Skip if there are no parameters
                if param.parameters and not param.match_parameters(sub_op):
                    return False
            elif getattr(op, key) != param:
                return False
        return True

    def __str__(self):
        return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):
    """Unification stream handler matching an `Op` instance against an `OpPattern`.

    Yields the (possibly extended) substitution mapping on success, or
    ``False`` when the Op type or any specified parameter fails to unify.
    """
    # Fail fast if the Op is not an instance of the pattern's op_type.
    if not isinstance(v, u.op_type):
        yield False
        return
    for parameter_key, parameter_pattern in u.parameters:
        parameter_value = getattr(v, parameter_key)
        # Delegate to the generic `_unify` dispatcher so parameter patterns may
        # be unification variables or nested OpPatterns.  The bare `yield` hands
        # the sub-goal generator to unification's driver, which is expected to
        # resume us with the resulting substitution (or False on mismatch).
        new_s = yield _unify(parameter_value, parameter_pattern, s)
        if new_s is False:
            yield False
            return
        # Thread the extended substitution through the remaining parameters.
        s = new_s
    yield s


# Register the handler so `unify(op_instance, op_pattern, s)` dispatches here.
_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op)
def convert_strs_to_vars(
x: tuple | str | dict, var_map: dict[str, Var] | None = None
) -> ExpressionTuple | Var:
......@@ -266,11 +427,13 @@ def convert_strs_to_vars(
if var_map is None:
var_map = {}
def _convert(y):
def _convert(y, op_prop=False):
if isinstance(y, str):
v = var_map.get(y, var(y))
var_map[y] = v
return v
if isinstance(y, LiteralString):
return y.value
elif isinstance(y, dict):
pattern = y["pattern"]
if not isinstance(pattern, str):
......@@ -282,8 +445,14 @@ def convert_strs_to_vars(
var_map[pattern] = v
return v
elif isinstance(y, tuple):
return etuple(*(_convert(e) for e in y))
elif isinstance(y, Number | np.ndarray):
return etuple(*(_convert(e, op_prop=op_prop) for e in y))
elif isinstance(y, OpPattern):
return OpPattern(
y.op_type,
{k: _convert(v, op_prop=True) for k, v in y.parameters},
)
elif (not op_prop) and isinstance(y, Number | np.ndarray):
# If we are converting an Op property, we don't want to convert numbers to PyTensor constants
from pytensor.tensor import as_tensor_variable
return as_tensor_variable(y)
......
......@@ -18,6 +18,7 @@ from pytensor.graph.rewriting.basic import (
pre_constant_merge,
pre_greedy_node_rewriter,
)
from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding
......@@ -283,6 +284,42 @@ class TestPatternNodeRewriter:
str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))"
def test_op_pattern(self):
a = MyVariable("a")
e1 = MyOp(name="MyOp(x=1)", x=1)(a)
e2 = MyOp(name="MyOp(x=2)", x=2)(a)
e_hello = MyOp(name="MyOp(x='hello')", x="hello")(a)
op_x3 = MyOp(name="MyOp(x=3)", x=3)
assert not equal_computations([e1], [op_x3(a)])
assert not equal_computations([e2], [op_x3(a)])
rewriter = WalkingPatternNodeRewriter(
(OpPattern(MyOp, x=1), "a"),
"a",
)
g = FunctionGraph([a], [e1, e2, e1], copy_inputs=False)
rewriter.rewrite(g)
assert equal_computations(g.outputs, [a, e2, a])
rewriter = WalkingPatternNodeRewriter(
(OpPattern(MyOp, x="x"), "a"),
lambda fgraph, node, subs: (
MyOp(name="MyOp(x+=10)", x=subs["x"] + 10)(subs["a"])
if subs["x"] < 10
else False
),
)
g = FunctionGraph([a], [e1], copy_inputs=False)
rewriter.rewrite(g)
assert equal_computations(g.outputs, [MyOp(name="x=11", x=11)(a)])
rewriter = WalkingPatternNodeRewriter(
(OpPattern(MyOp, x=LiteralString("hello")), "a"), "a"
)
g = FunctionGraph([a], [e1, e_hello], copy_inputs=False)
rewriter.rewrite(g)
assert equal_computations(g.outputs, [e1, a])
class NoInputOp(Op):
__props__ = ("param",)
......
......@@ -11,7 +11,11 @@ import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph.basic import Apply, Constant, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars
from pytensor.graph.rewriting.unify import (
ConstrainedVar,
OpPattern,
convert_strs_to_vars,
)
from pytensor.tensor.type import TensorType
from tests.graph.utils import MyType
......@@ -348,3 +352,25 @@ def test_convert_strs_to_vars():
res = convert_strs_to_vars((val,))
assert isinstance(res[0], Constant)
assert np.array_equal(res[0].data, val)
def test_unify_OpPattern():
    """`unify` on etuple patterns built from OpPattern against applied Ops."""
    x_pt = MyType()("x_pt")
    y_pt = MyType()("y_pt")
    out1 = CustomOp(a=1)(x_pt, y_pt)
    out2 = CustomOp(a=2)(x_pt, y_pt)

    x, y = var("x"), var("y")

    # No parameters specified: any CustomOp instance matches.
    unconstrained = etuple(OpPattern(CustomOp), x, y)
    assert unify(unconstrained, out1) == {x: x_pt, y: y_pt}
    assert unify(unconstrained, out2) == {x: x_pt, y: y_pt}

    # Concrete parameter value: only a=1 unifies.
    constrained = etuple(OpPattern(CustomOp, a=1), x, y)
    assert unify(constrained, out1) == {x: x_pt, y: y_pt}
    assert unify(constrained, out2) is False

    # Logic-variable parameter: binds `a` to the Op's actual value.
    a = var("a")
    parametrized = etuple(OpPattern(CustomOp, a=a), x, y)
    assert unify(parametrized, out1) == {x: x_pt, y: y_pt, a: 1}
    assert unify(parametrized, out2) == {x: x_pt, y: y_pt, a: 2}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论