提交 085f2723 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Allow unifying with OpPattern

上级 51ab571a
...@@ -29,7 +29,7 @@ from pytensor.graph.basic import ( ...@@ -29,7 +29,7 @@ from pytensor.graph.basic import (
from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten from pytensor.utils import flatten
...@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
The input and output patterns have the following syntax: The input and output patterns have the following syntax:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...) input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>, input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>) constraint = <constraint>)
sub_pattern ::= input_pattern sub_pattern ::= input_pattern
...@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
output_pattern ::= string output_pattern ::= string
output_pattern ::= int output_pattern ::= int
output_pattern ::= float output_pattern ::= float
output_pattern ::= callable
Each string in the input pattern is a variable that will be set to Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is whatever expression is found in its place. If the same string is
...@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
Examples Examples
-------- --------
PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x')) .. code-block:: python
PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternNodeRewriter((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import add, mul, sub, pow, square
PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
PatternNodeRewriter((mul, "x", "x"), (square, "x"))
PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
PatternNodeRewriter(
(mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
(mul, "y", "x"),
)
You can use OpPattern to match a subtype of an Op, with some parameter constraints
You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.basic import Join
from pytensor.tensor.elemwise import CAReduce, Elemwise
def output_fn(fgraph, node, s):
reduce_op = node.op
reduced_a = reduce_op(s["a"])
reduced_b = reduce_op(s["b"])
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
PatternNodeRewriter(
(
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
(Join(), "join_axis", "a", "b"),
),
output_fn,
)
If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
PatternNodeRewriter(
(
OpPattern(
Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
),
"A",
"b",
)
)
""" """
def __init__( def __init__(
self, self,
in_pattern, in_pattern: tuple,
out_pattern, out_pattern: tuple | Callable | str,
allow_multiple_clients: bool = False, allow_multiple_clients: bool = False,
name: str | None = None, name: str | None = None,
tracks=(), tracks=(),
...@@ -1378,7 +1433,8 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1378,7 +1433,8 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern in_pattern
The input pattern that we want to replace. The input pattern that we want to replace.
out_pattern out_pattern
The replacement pattern. The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs,
and returns the replacement variable (or None/False to reject the rewrite).
allow_multiple_clients allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client. more than one client.
...@@ -1407,26 +1463,40 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1407,26 +1463,40 @@ class PatternNodeRewriter(NodeRewriter):
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
self.values_eq_approx = values_eq_approx self.values_eq_approx = values_eq_approx
self.allow_cast = allow_cast self.allow_cast = allow_cast
if isinstance(in_pattern, list | tuple):
self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
self.op = self.in_pattern["pattern"][0]
else:
raise TypeError(
"The pattern to search for must start with a specific Op instance."
)
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
if name: if name:
self.__name__ = name self.__name__ = name
self._tracks = tracks
self.get_nodes = get_nodes self.get_nodes = get_nodes
if tracks != (): if tracks != ():
assert get_nodes if not get_nodes:
raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
self._tracks = tracks
else:
if isinstance(in_pattern, list | tuple):
op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
op = self.in_pattern["pattern"][0]
else:
raise TypeError(
f"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}"
)
if isinstance(op, Op):
self._tracks = [op]
elif isinstance(op, type) and issubclass(op, Op):
raise ValueError(
f"The in_pattern starts with an Op class {op}, not an instance.\n"
"You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
)
elif isinstance(op, OpPattern):
self._tracks = [op.op_type]
else:
raise ValueError(
f"The in_pattern must start with a specific Op or an OpPattern instance. "
f"Got {op}, with type {type(op)}."
)
def tracks(self): def tracks(self):
if self._tracks != (): return self._tracks
return self._tracks
return [self.op]
def transform(self, fgraph, node, get_nodes=True): def transform(self, fgraph, node, get_nodes=True):
"""Check if the graph from node corresponds to ``in_pattern``. """Check if the graph from node corresponds to ``in_pattern``.
...@@ -1447,28 +1517,39 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1447,28 +1517,39 @@ class PatternNodeRewriter(NodeRewriter):
# PatternNodeRewriter doesn't support replacing multi-output nodes # PatternNodeRewriter doesn't support replacing multi-output nodes
return False return False
s = unify(self.in_pattern, node.out) s = unify(self.in_pattern, node.out, {})
if s is False: if s is False:
return False return False
ret = reify(self.out_pattern, s)
if isinstance(ret, ExpressionTuple):
ret = ret.evaled_obj
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
if not self.allow_multiple_clients: if not self.allow_multiple_clients:
input_vars = list(s.values()) input_vars = set(s.values())
clients = fgraph.clients
if any( if any(
len(fgraph.clients[v]) > 1 len(clients[v]) > 1
for v in vars_between(input_vars, node.inputs) for v in vars_between(input_vars, node.inputs)
if v not in input_vars if v not in input_vars
): ):
return False return False
if callable(self.out_pattern):
# token is the variable name used in the original pattern
ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
if ret is None or ret is False:
# The output function is still allowed to reject the rewrite
return False
if not isinstance(ret, Variable):
raise ValueError(
f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
)
else:
ret = reify(self.out_pattern, s)
if isinstance(ret, ExpressionTuple):
ret = ret.evaled_obj
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
[old_out] = node.outputs [old_out] = node.outputs
if not old_out.type.is_super(ret.type): if not old_out.type.is_super(ret.type):
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
......
...@@ -86,7 +86,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -86,7 +86,7 @@ class KanrenRelationSub(NodeRewriter):
q = var() q = var()
kanren_results = run(None, q, self.kanren_relation(input_expr, q)) kanren_results = run(None, q, self.kanren_relation(input_expr, q))
chosen_res = self.results_filter(kanren_results) chosen_res = self.results_filter(kanren_results) # type: ignore[arg-type]
if chosen_res: if chosen_res:
if isinstance(chosen_res, list): if isinstance(chosen_res, list):
......
...@@ -10,8 +10,11 @@ that satisfies the constraints. That's useful for pattern matching. ...@@ -10,8 +10,11 @@ that satisfies the constraints. That's useful for pattern matching.
""" """
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from numbers import Number from numbers import Number
from types import UnionType
from typing import Any, TypeAlias
import numpy as np import numpy as np
from cons.core import ConsError, _car, _cdr from cons.core import ConsError, _car, _cdr
...@@ -254,6 +257,164 @@ def _unify_ConstrainedVar_object(u, v, s): ...@@ -254,6 +257,164 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) _unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)
@dataclass(frozen=True)
class LiteralString:
    """Wrap a string so pattern conversion treats it as a literal value.

    Bare strings inside rewrite patterns are interpreted as unification
    variables; wrapping one in ``LiteralString`` makes it match the exact
    string instead (e.g. a string-valued Op parameter such as
    ``assume_a="gen"``). Frozen so instances are hashable and usable
    inside hashed ``OpPattern`` parameter tuples.
    """

    # The exact string to match; unwrapped back to a plain str during
    # pattern conversion (see `convert_strs_to_vars`).
    value: str
# Acceptable specifications of the Op type(s) an OpPattern may match:
# a single Op subclass, a tuple of subclasses, or a union type (A | B).
OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType


@dataclass(unsafe_hash=True)
class OpPattern:
    """Class that can be unified with Op instances of a given type (or instance) and parameters.

    Parameters that are not specified in the OpPattern are ignored during unification.
    This is needed because some Ops can be complex to parametrize fully,
    and not all parameters are relevant for a given pattern.

    Examples
    --------
    OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.

    The example below matches two nested CAReduce Ops with the same `scalar_op`,
    the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.

    Note, that because we didn't specify it, the axis of the inner CAReduce can be anything.
    The same goes for other properties of the Op that are not specified in the OpPattern.

    .. testcode::

        from pytensor.graph.rewriting.basic import PatternNodeRewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.basic import Join
        from pytensor.tensor.elemwise import CAReduce, Elemwise

        def output_fn(fgraph, node, s):
            reduce_op = node.op
            reduced_a = reduce_op(s["a"])
            reduced_b = reduce_op(s["b"])
            return Elemwise(s["scalar_op"])(reduced_a, reduced_b)

        PatternNodeRewriter(
            in_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
                        (OpPattern(CAReduce, scalar_op="scalar_op"), "x")),
            out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"),
        )

    OpPattern can also be used with `unification.unify` to match Ops with specific parameters.
    This is used by PatternNodeRewriter but can also be used directly.

    .. testcode::

        from unification import var, unify
        from etuples import etuple

        import pytensor.tensor as pt
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        A = var("A")
        b = var("b")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")),
            A,
            b,
        )

        A_pt = pt.tensor3("A")
        b_pt = pt.tensor3("b")
        out1 = pt.linalg.solve(A_pt, b_pt)
        out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")

        assert unify(pattern, out1) == {A: A_pt, b: b_pt}
        assert unify(pattern, out2) is False

        assume_a = var("assume_a")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
            A,
            b,
        )
        assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
        assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
    """

    op_type: OpPatternOpTypeType
    # Parameter constraints, normalized to a key-sorted tuple of (name, value)
    # pairs so that equivalent patterns hash and compare equal.
    # NOTE: the variadic `...` is required; a bare tuple[tuple[str, Any]]
    # would declare a 1-tuple.
    parameters: tuple[tuple[str, Any], ...]

    def __init__(
        self,
        op_type: OpPatternOpTypeType,
        parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,
        **kwargs,
    ):
        """Build an OpPattern.

        Parameters
        ----------
        op_type
            Op subclass (or tuple/union of subclasses) candidate Ops must be
            instances of.
        parameters
            Mapping (or sequence of ``(name, value)`` pairs) from Op attribute
            names to required values. Values may themselves be OpPatterns,
            logic variables, or concrete values. Mutually exclusive with
            ``**kwargs``.
        **kwargs
            Convenience keyword form of ``parameters``.
        """
        if kwargs:
            if parameters is not None:
                raise ValueError(
                    "Cannot provide both parameters dict and keyword arguments"
                )
            parameters = kwargs

        if isinstance(parameters, dict):
            # dict keys are unique, so sorting the items never compares values.
            parameters = tuple(sorted(parameters.items()))
        elif isinstance(parameters, list | tuple):
            # Sort by parameter name only: values (e.g., nested OpPatterns or
            # logic variables) need not be orderable among themselves.
            # Coerce each pair to a tuple so the dataclass hash stays valid
            # even when the caller passed inner lists.
            parameters = tuple(
                (key, value)
                for key, value in sorted(parameters, key=lambda item: item[0])
            )
        elif parameters is None:
            parameters = ()
        self.op_type = op_type
        self.parameters = parameters  # type: ignore[assignment]

    def match_op(self, op: Op):
        """Return whether `op` is an instance of `op_type` with matching parameters."""
        if not isinstance(op, self.op_type):
            return False
        return self.match_parameters(op)

    def match_parameters(self, op):
        """Check only the parameter constraints of `op`.

        This is used by methods that already check the op_type is satisfied.
        Some methods may index on the op_type and know in advance the op is matched.
        Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below).
        """
        for key, param in self.parameters:
            if isinstance(param, OpPattern):
                # Parameters can themselves be other OpPatterns.
                # We check the op_type here to reject early without a nested call.
                sub_op = getattr(op, key)
                if not isinstance(sub_op, param.op_type):
                    return False
                # Match the pattern of the inner Op; skip if it has no parameters.
                if param.parameters and not param.match_parameters(sub_op):
                    return False
            elif getattr(op, key) != param:
                return False
        return True

    def __str__(self):
        return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):
    """Unify a concrete `Op` instance `v` with an `OpPattern` `u` under substitution `s`.

    Generator-based handler for the `unification` library's `_unify`
    dispatcher: each ``yield _unify(...)`` hands a nested unification back to
    the library's driver, which sends the resulting substitution into the
    generator. Finally yields the extended substitution on success, or
    ``False`` on any mismatch.
    """
    # The Op must be an instance of the pattern's type spec
    # (a class, tuple of classes, or union — all valid isinstance targets).
    if not isinstance(v, u.op_type):
        yield False
        return
    for parameter_key, parameter_pattern in u.parameters:
        parameter_value = getattr(v, parameter_key)
        # Delegate to _unify so parameter patterns may themselves be logic
        # variables, nested OpPatterns, or plain values.
        new_s = yield _unify(parameter_value, parameter_pattern, s)
        if new_s is False:
            yield False
            return
        # Thread the accumulated substitution through the remaining parameters.
        s = new_s
    yield s


# Register the handler: applies when unifying a concrete Op (left) against
# an OpPattern (right) under a Mapping substitution.
_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op)
def convert_strs_to_vars( def convert_strs_to_vars(
x: tuple | str | dict, var_map: dict[str, Var] | None = None x: tuple | str | dict, var_map: dict[str, Var] | None = None
) -> ExpressionTuple | Var: ) -> ExpressionTuple | Var:
...@@ -266,11 +427,13 @@ def convert_strs_to_vars( ...@@ -266,11 +427,13 @@ def convert_strs_to_vars(
if var_map is None: if var_map is None:
var_map = {} var_map = {}
def _convert(y): def _convert(y, op_prop=False):
if isinstance(y, str): if isinstance(y, str):
v = var_map.get(y, var(y)) v = var_map.get(y, var(y))
var_map[y] = v var_map[y] = v
return v return v
if isinstance(y, LiteralString):
return y.value
elif isinstance(y, dict): elif isinstance(y, dict):
pattern = y["pattern"] pattern = y["pattern"]
if not isinstance(pattern, str): if not isinstance(pattern, str):
...@@ -282,8 +445,14 @@ def convert_strs_to_vars( ...@@ -282,8 +445,14 @@ def convert_strs_to_vars(
var_map[pattern] = v var_map[pattern] = v
return v return v
elif isinstance(y, tuple): elif isinstance(y, tuple):
return etuple(*(_convert(e) for e in y)) return etuple(*(_convert(e, op_prop=op_prop) for e in y))
elif isinstance(y, Number | np.ndarray): elif isinstance(y, OpPattern):
return OpPattern(
y.op_type,
{k: _convert(v, op_prop=True) for k, v in y.parameters},
)
elif (not op_prop) and isinstance(y, Number | np.ndarray):
# If we are converting an Op property, we don't want to convert numbers to PyTensor constants
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
return as_tensor_variable(y) return as_tensor_variable(y)
......
...@@ -18,6 +18,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -18,6 +18,7 @@ from pytensor.graph.rewriting.basic import (
pre_constant_merge, pre_constant_merge,
pre_greedy_node_rewriter, pre_greedy_node_rewriter,
) )
from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding from pytensor.tensor.rewriting.basic import constant_folding
...@@ -283,6 +284,42 @@ class TestPatternNodeRewriter: ...@@ -283,6 +284,42 @@ class TestPatternNodeRewriter:
str_g = str(g) str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))" assert str_g == "FunctionGraph(Op4(z, y))"
def test_op_pattern(self):
    """PatternNodeRewriter driven by OpPattern matches MyOp nodes by their `x` parameter."""
    inp = MyVariable("a")
    out_x1 = MyOp(name="MyOp(x=1)", x=1)(inp)
    out_x2 = MyOp(name="MyOp(x=2)", x=2)(inp)
    out_hello = MyOp(name="MyOp(x='hello')", x="hello")(inp)

    # Sanity check: MyOp instances with different `x` are distinct computations.
    other_op = MyOp(name="MyOp(x=3)", x=3)
    assert not equal_computations([out_x1], [other_op(inp)])
    assert not equal_computations([out_x2], [other_op(inp)])

    # Fixed parameter value: only MyOp(x=1) applications are replaced by their input.
    rewriter = WalkingPatternNodeRewriter(
        (OpPattern(MyOp, x=1), "a"),
        "a",
    )
    fg = FunctionGraph([inp], [out_x1, out_x2, out_x1], copy_inputs=False)
    rewriter.rewrite(fg)
    assert equal_computations(fg.outputs, [inp, out_x2, inp])

    # Parameter captured as a unification variable, combined with a callable
    # output pattern that can also reject the rewrite by returning False.
    rewriter = WalkingPatternNodeRewriter(
        (OpPattern(MyOp, x="x"), "a"),
        lambda fgraph, node, subs: (
            MyOp(name="MyOp(x+=10)", x=subs["x"] + 10)(subs["a"])
            if subs["x"] < 10
            else False
        ),
    )
    fg = FunctionGraph([inp], [out_x1], copy_inputs=False)
    rewriter.rewrite(fg)
    assert equal_computations(fg.outputs, [MyOp(name="x=11", x=11)(inp)])

    # LiteralString marks "hello" as a literal value rather than a variable name.
    rewriter = WalkingPatternNodeRewriter(
        (OpPattern(MyOp, x=LiteralString("hello")), "a"), "a"
    )
    fg = FunctionGraph([inp], [out_x1, out_hello], copy_inputs=False)
    rewriter.rewrite(fg)
    assert equal_computations(fg.outputs, [out_x1, inp])
class NoInputOp(Op): class NoInputOp(Op):
__props__ = ("param",) __props__ = ("param",)
......
...@@ -11,7 +11,11 @@ import pytensor.scalar as ps ...@@ -11,7 +11,11 @@ import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph.basic import Apply, Constant, equal_computations from pytensor.graph.basic import Apply, Constant, equal_computations
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars from pytensor.graph.rewriting.unify import (
ConstrainedVar,
OpPattern,
convert_strs_to_vars,
)
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from tests.graph.utils import MyType from tests.graph.utils import MyType
...@@ -348,3 +352,25 @@ def test_convert_strs_to_vars(): ...@@ -348,3 +352,25 @@ def test_convert_strs_to_vars():
res = convert_strs_to_vars((val,)) res = convert_strs_to_vars((val,))
assert isinstance(res[0], Constant) assert isinstance(res[0], Constant)
assert np.array_equal(res[0].data, val) assert np.array_equal(res[0].data, val)
def test_unify_OpPattern():
    """OpPattern unifies with Op instances, optionally constraining parameters."""
    lhs = MyType()("x_pt")
    rhs = MyType()("y_pt")
    node_a1 = CustomOp(a=1)(lhs, rhs)
    node_a2 = CustomOp(a=2)(lhs, rhs)

    x, y = var("x"), var("y")

    # No parameter constraints: any CustomOp instance matches.
    unconstrained = etuple(OpPattern(CustomOp), x, y)
    assert unify(unconstrained, node_a1) == {x: lhs, y: rhs}
    assert unify(unconstrained, node_a2) == {x: lhs, y: rhs}

    # Concrete parameter value: only the instance with a=1 unifies.
    fixed = etuple(OpPattern(CustomOp, a=1), x, y)
    assert unify(fixed, node_a1) == {x: lhs, y: rhs}
    assert unify(fixed, node_a2) is False

    # Parameter given as a logic variable: its concrete value is captured.
    a = var("a")
    captured = etuple(OpPattern(CustomOp, a=a), x, y)
    assert unify(captured, node_a1) == {x: lhs, y: rhs, a: 1}
    assert unify(captured, node_a2) == {x: lhs, y: rhs, a: 2}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论