Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
085f2723
提交
085f2723
authored
8月 27, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow unifying with OpPattern
上级
51ab571a
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
353 行增加
和
40 行删除
+353
-40
basic.py
pytensor/graph/rewriting/basic.py
+115
-34
kanren.py
pytensor/graph/rewriting/kanren.py
+1
-1
unify.py
pytensor/graph/rewriting/unify.py
+173
-4
test_basic.py
tests/graph/rewriting/test_basic.py
+37
-0
test_unify.py
tests/graph/rewriting/test_unify.py
+27
-1
没有找到文件。
pytensor/graph/rewriting/basic.py
浏览文件 @
085f2723
...
...
@@ -29,7 +29,7 @@ from pytensor.graph.basic import (
from
pytensor.graph.features
import
AlreadyThere
,
Feature
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.unify
import
Var
,
convert_strs_to_vars
from
pytensor.graph.rewriting.unify
import
OpPattern
,
Var
,
convert_strs_to_vars
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.utils
import
flatten
...
...
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
The input and output patterns have the following syntax:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
...
...
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
output_pattern ::= callable
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
...
...
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
Examples
--------
PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternNodeRewriter((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import add, mul, sub, pow, square
PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
PatternNodeRewriter((mul, "x", "x"), (square, "x"))
PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
PatternNodeRewriter(
(mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
(mul, "y", "x"),
)
You can use OpPattern to match a subtype of an Op, with some parameter constraints
You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.basic import Join
from pytensor.tensor.elemwise import CAReduce, Elemwise
def output_fn(fgraph, node, s):
reduce_op = node.op
reduced_a = reduce_op(s["a"])
reduced_b = reduce_op(s["b"])
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
PatternNodeRewriter(
(
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
(Join(), "join_axis", "a", "b"),
),
output_fn,
)
If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
.. code-block:: python
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
PatternNodeRewriter(
(
OpPattern(
Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
),
"A",
"b",
)
)
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
in_pattern
:
tuple
,
out_pattern
:
tuple
|
Callable
|
str
,
allow_multiple_clients
:
bool
=
False
,
name
:
str
|
None
=
None
,
tracks
=
(),
...
...
@@ -1378,7 +1433,8 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs,
and returns the replacement variable (or None/False to reject the rewrite).
allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client.
...
...
@@ -1407,26 +1463,40 @@ class PatternNodeRewriter(NodeRewriter):
self
.
out_pattern
=
convert_strs_to_vars
(
out_pattern
,
var_map
=
var_map
)
self
.
values_eq_approx
=
values_eq_approx
self
.
allow_cast
=
allow_cast
if
isinstance
(
in_pattern
,
list
|
tuple
):
self
.
op
=
self
.
in_pattern
[
0
]
elif
isinstance
(
in_pattern
,
dict
):
self
.
op
=
self
.
in_pattern
[
"pattern"
][
0
]
else
:
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
self
.
allow_multiple_clients
=
allow_multiple_clients
if
name
:
self
.
__name__
=
name
self
.
_tracks
=
tracks
self
.
get_nodes
=
get_nodes
if
tracks
!=
():
assert
get_nodes
if
not
get_nodes
:
raise
ValueError
(
"Custom `tracks` requires `get_nodes` to be provided."
)
self
.
_tracks
=
tracks
else
:
if
isinstance
(
in_pattern
,
list
|
tuple
):
op
=
self
.
in_pattern
[
0
]
elif
isinstance
(
in_pattern
,
dict
):
op
=
self
.
in_pattern
[
"pattern"
][
0
]
else
:
raise
TypeError
(
f
"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}"
)
if
isinstance
(
op
,
Op
):
self
.
_tracks
=
[
op
]
elif
isinstance
(
op
,
type
)
and
issubclass
(
op
,
Op
):
raise
ValueError
(
f
"The in_pattern starts with an Op class {op}, not an instance.
\n
"
"You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
)
elif
isinstance
(
op
,
OpPattern
):
self
.
_tracks
=
[
op
.
op_type
]
else
:
raise
ValueError
(
f
"The in_pattern must start with a specific Op or an OpPattern instance. "
f
"Got {op}, with type {type(op)}."
)
def
tracks
(
self
):
if
self
.
_tracks
!=
():
return
self
.
_tracks
return
[
self
.
op
]
def
transform
(
self
,
fgraph
,
node
,
get_nodes
=
True
):
"""Check if the graph from node corresponds to ``in_pattern``.
...
...
@@ -1447,28 +1517,39 @@ class PatternNodeRewriter(NodeRewriter):
# PatternNodeRewriter doesn't support replacing multi-output nodes
return
False
s
=
unify
(
self
.
in_pattern
,
node
.
out
)
s
=
unify
(
self
.
in_pattern
,
node
.
out
,
{}
)
if
s
is
False
:
return
False
ret
=
reify
(
self
.
out_pattern
,
s
)
if
isinstance
(
ret
,
ExpressionTuple
):
ret
=
ret
.
evaled_obj
if
self
.
values_eq_approx
:
ret
.
tag
.
values_eq_approx
=
self
.
values_eq_approx
if
not
self
.
allow_multiple_clients
:
input_vars
=
list
(
s
.
values
())
input_vars
=
set
(
s
.
values
())
clients
=
fgraph
.
clients
if
any
(
len
(
fgraph
.
clients
[
v
])
>
1
len
(
clients
[
v
])
>
1
for
v
in
vars_between
(
input_vars
,
node
.
inputs
)
if
v
not
in
input_vars
):
return
False
if
callable
(
self
.
out_pattern
):
# token is the variable name used in the original pattern
ret
=
self
.
out_pattern
(
fgraph
,
node
,
{
k
.
token
:
v
for
k
,
v
in
s
.
items
()})
if
ret
is
None
or
ret
is
False
:
# The output function is still allowed to reject the rewrite
return
False
if
not
isinstance
(
ret
,
Variable
):
raise
ValueError
(
f
"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
)
else
:
ret
=
reify
(
self
.
out_pattern
,
s
)
if
isinstance
(
ret
,
ExpressionTuple
):
ret
=
ret
.
evaled_obj
if
self
.
values_eq_approx
:
ret
.
tag
.
values_eq_approx
=
self
.
values_eq_approx
[
old_out
]
=
node
.
outputs
if
not
old_out
.
type
.
is_super
(
ret
.
type
):
from
pytensor.tensor.type
import
TensorType
...
...
pytensor/graph/rewriting/kanren.py
浏览文件 @
085f2723
...
...
@@ -86,7 +86,7 @@ class KanrenRelationSub(NodeRewriter):
q
=
var
()
kanren_results
=
run
(
None
,
q
,
self
.
kanren_relation
(
input_expr
,
q
))
chosen_res
=
self
.
results_filter
(
kanren_results
)
chosen_res
=
self
.
results_filter
(
kanren_results
)
# type: ignore[arg-type]
if
chosen_res
:
if
isinstance
(
chosen_res
,
list
):
...
...
pytensor/graph/rewriting/unify.py
浏览文件 @
085f2723
...
...
@@ -10,8 +10,11 @@ that satisfies the constraints. That's useful for pattern matching.
"""
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
numbers
import
Number
from
types
import
UnionType
from
typing
import
Any
,
TypeAlias
import
numpy
as
np
from
cons.core
import
ConsError
,
_car
,
_cdr
...
...
@@ -254,6 +257,164 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify
.
add
((
object
,
ConstrainedVar
,
Mapping
),
_unify_ConstrainedVar_object
)
@dataclass
(
frozen
=
True
)
class
LiteralString
:
value
:
str
OpPatternOpTypeType
:
TypeAlias
=
type
[
Op
]
|
tuple
[
type
[
Op
],
...
]
|
UnionType
@dataclass
(
unsafe_hash
=
True
)
class
OpPattern
:
"""Class that can be unified with Op instances of a given type (or instance) and parameters.
Parameters that are not specified in the OpPattern are ignored during unification.
This is needed because some Ops can be complex to parametrize fully,
and not all parameters are relevant for a given pattern.
Examples
--------
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
The example below matches two nested CAReduce Ops with the same `scalar_op`,
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.
Note, that because we didn't specify it, the axis of the inner CAReduce can be anything.
The same goes for other properties of the Op that are not specified in the OpPattern.
.. testcode::
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.basic import Join
from pytensor.tensor.elemwise import CAReduce, Elemwise
def output_fn(fgraph, node, s):
reduce_op = node.op
reduced_a = reduce_op(s["a"])
reduced_b = reduce_op(s["b"])
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
PatternNodeRewriter(
in_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
(OpPattern(CAReduce, scalar_op="scalar_op",), "x")),
out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"),
)
OpPattern can also be used with `unification.unify` to match Ops with specific parameters.
This is used by PatternNodeRewriter but can also be used directly.
.. testcode::
from unification import var, unify
from etuples import etuple
import pytensor.tensor as pt
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve
A = var("A")
b = var("b")
pattern = etuple(
OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")),
A,
b,
)
A_pt = pt.tensor3("A")
b_pt = pt.tensor3("b")
out1 = pt.linalg.solve(A_pt, b_pt)
out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")
assert unify(pattern, out1) == {A: A_pt, b: b_pt}
assert unify(pattern, out2) is False
assume_a = var("assume_a")
pattern = etuple(
OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
A,
b,
)
assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
"""
op_type
:
OpPatternOpTypeType
parameters
:
tuple
[
tuple
[
str
,
Any
]]
def
__init__
(
self
,
op_type
:
OpPatternOpTypeType
,
parameters
:
dict
[
str
,
Any
]
|
Sequence
[
tuple
[
str
,
Any
]]
|
None
=
None
,
**
kwargs
,
):
if
kwargs
:
if
parameters
is
not
None
:
raise
ValueError
(
"Cannot provide both parameters dict and keyword arguments"
)
parameters
=
kwargs
if
isinstance
(
parameters
,
dict
):
parameters
=
tuple
(
sorted
(
parameters
.
items
()))
elif
isinstance
(
parameters
,
list
|
tuple
):
parameters
=
tuple
(
sorted
(
parameters
))
elif
parameters
is
None
:
parameters
=
()
self
.
op_type
=
op_type
self
.
parameters
=
parameters
# type: ignore[assignment]
def
match_op
(
self
,
op
:
Op
):
if
not
isinstance
(
op
,
self
.
op_type
):
return
False
return
self
.
match_parameters
(
op
)
def
match_parameters
(
self
,
op
):
# This is used by methods that already check the op_type is satisfied
# Some methods may index on the op_type and know in advance the op is matched
# Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below)
for
key
,
param
in
self
.
parameters
:
if
isinstance
(
param
,
OpPattern
):
# Parameters can itself be other OpPatterns
# We check the op_type to avoid a nested call in cases we can reject early
sub_op
=
getattr
(
op
,
key
)
if
not
isinstance
(
sub_op
,
param
.
op_type
):
return
False
# Match the pattern of the inner Op
# Skip if there are no parameters
if
param
.
parameters
and
not
param
.
match_parameters
(
sub_op
):
return
False
elif
getattr
(
op
,
key
)
!=
param
:
return
False
return
True
def
__str__
(
self
):
return
f
"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
def
_unify_parametrized_op
(
v
:
Op
,
u
:
OpPattern
,
s
:
Mapping
):
if
not
isinstance
(
v
,
u
.
op_type
):
yield
False
return
for
parameter_key
,
parameter_pattern
in
u
.
parameters
:
parameter_value
=
getattr
(
v
,
parameter_key
)
new_s
=
yield
_unify
(
parameter_value
,
parameter_pattern
,
s
)
if
new_s
is
False
:
yield
False
return
s
=
new_s
yield
s
_unify
.
add
((
Op
,
OpPattern
,
Mapping
),
_unify_parametrized_op
)
def
convert_strs_to_vars
(
x
:
tuple
|
str
|
dict
,
var_map
:
dict
[
str
,
Var
]
|
None
=
None
)
->
ExpressionTuple
|
Var
:
...
...
@@ -266,11 +427,13 @@ def convert_strs_to_vars(
if
var_map
is
None
:
var_map
=
{}
def
_convert
(
y
):
def
_convert
(
y
,
op_prop
=
False
):
if
isinstance
(
y
,
str
):
v
=
var_map
.
get
(
y
,
var
(
y
))
var_map
[
y
]
=
v
return
v
if
isinstance
(
y
,
LiteralString
):
return
y
.
value
elif
isinstance
(
y
,
dict
):
pattern
=
y
[
"pattern"
]
if
not
isinstance
(
pattern
,
str
):
...
...
@@ -282,8 +445,14 @@ def convert_strs_to_vars(
var_map
[
pattern
]
=
v
return
v
elif
isinstance
(
y
,
tuple
):
return
etuple
(
*
(
_convert
(
e
)
for
e
in
y
))
elif
isinstance
(
y
,
Number
|
np
.
ndarray
):
return
etuple
(
*
(
_convert
(
e
,
op_prop
=
op_prop
)
for
e
in
y
))
elif
isinstance
(
y
,
OpPattern
):
return
OpPattern
(
y
.
op_type
,
{
k
:
_convert
(
v
,
op_prop
=
True
)
for
k
,
v
in
y
.
parameters
},
)
elif
(
not
op_prop
)
and
isinstance
(
y
,
Number
|
np
.
ndarray
):
# If we are converting an Op property, we don't want to convert numbers to PyTensor constants
from
pytensor.tensor
import
as_tensor_variable
return
as_tensor_variable
(
y
)
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
085f2723
...
...
@@ -18,6 +18,7 @@ from pytensor.graph.rewriting.basic import (
pre_constant_merge
,
pre_greedy_node_rewriter
,
)
from
pytensor.graph.rewriting.unify
import
LiteralString
,
OpPattern
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
from
pytensor.tensor.rewriting.basic
import
constant_folding
...
...
@@ -283,6 +284,42 @@ class TestPatternNodeRewriter:
str_g
=
str
(
g
)
assert
str_g
==
"FunctionGraph(Op4(z, y))"
def
test_op_pattern
(
self
):
a
=
MyVariable
(
"a"
)
e1
=
MyOp
(
name
=
"MyOp(x=1)"
,
x
=
1
)(
a
)
e2
=
MyOp
(
name
=
"MyOp(x=2)"
,
x
=
2
)(
a
)
e_hello
=
MyOp
(
name
=
"MyOp(x='hello')"
,
x
=
"hello"
)(
a
)
op_x3
=
MyOp
(
name
=
"MyOp(x=3)"
,
x
=
3
)
assert
not
equal_computations
([
e1
],
[
op_x3
(
a
)])
assert
not
equal_computations
([
e2
],
[
op_x3
(
a
)])
rewriter
=
WalkingPatternNodeRewriter
(
(
OpPattern
(
MyOp
,
x
=
1
),
"a"
),
"a"
,
)
g
=
FunctionGraph
([
a
],
[
e1
,
e2
,
e1
],
copy_inputs
=
False
)
rewriter
.
rewrite
(
g
)
assert
equal_computations
(
g
.
outputs
,
[
a
,
e2
,
a
])
rewriter
=
WalkingPatternNodeRewriter
(
(
OpPattern
(
MyOp
,
x
=
"x"
),
"a"
),
lambda
fgraph
,
node
,
subs
:
(
MyOp
(
name
=
"MyOp(x+=10)"
,
x
=
subs
[
"x"
]
+
10
)(
subs
[
"a"
])
if
subs
[
"x"
]
<
10
else
False
),
)
g
=
FunctionGraph
([
a
],
[
e1
],
copy_inputs
=
False
)
rewriter
.
rewrite
(
g
)
assert
equal_computations
(
g
.
outputs
,
[
MyOp
(
name
=
"x=11"
,
x
=
11
)(
a
)])
rewriter
=
WalkingPatternNodeRewriter
(
(
OpPattern
(
MyOp
,
x
=
LiteralString
(
"hello"
)),
"a"
),
"a"
)
g
=
FunctionGraph
([
a
],
[
e1
,
e_hello
],
copy_inputs
=
False
)
rewriter
.
rewrite
(
g
)
assert
equal_computations
(
g
.
outputs
,
[
e1
,
a
])
class
NoInputOp
(
Op
):
__props__
=
(
"param"
,)
...
...
tests/graph/rewriting/test_unify.py
浏览文件 @
085f2723
...
...
@@ -11,7 +11,11 @@ import pytensor.scalar as ps
import
pytensor.tensor
as
pt
from
pytensor.graph.basic
import
Apply
,
Constant
,
equal_computations
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.unify
import
ConstrainedVar
,
convert_strs_to_vars
from
pytensor.graph.rewriting.unify
import
(
ConstrainedVar
,
OpPattern
,
convert_strs_to_vars
,
)
from
pytensor.tensor.type
import
TensorType
from
tests.graph.utils
import
MyType
...
...
@@ -348,3 +352,25 @@ def test_convert_strs_to_vars():
res
=
convert_strs_to_vars
((
val
,))
assert
isinstance
(
res
[
0
],
Constant
)
assert
np
.
array_equal
(
res
[
0
]
.
data
,
val
)
def
test_unify_OpPattern
():
x_pt
=
MyType
()(
"x_pt"
)
y_pt
=
MyType
()(
"y_pt"
)
out1
=
CustomOp
(
a
=
1
)(
x_pt
,
y_pt
)
out2
=
CustomOp
(
a
=
2
)(
x_pt
,
y_pt
)
x
=
var
(
"x"
)
y
=
var
(
"y"
)
pattern
=
etuple
(
OpPattern
(
CustomOp
),
x
,
y
)
assert
unify
(
pattern
,
out1
)
==
{
x
:
x_pt
,
y
:
y_pt
}
assert
unify
(
pattern
,
out2
)
==
{
x
:
x_pt
,
y
:
y_pt
}
pattern
=
etuple
(
OpPattern
(
CustomOp
,
a
=
1
),
x
,
y
)
assert
unify
(
pattern
,
out1
)
==
{
x
:
x_pt
,
y
:
y_pt
}
assert
unify
(
pattern
,
out2
)
is
False
a
=
var
(
"a"
)
pattern
=
etuple
(
OpPattern
(
CustomOp
,
a
=
a
),
x
,
y
)
assert
unify
(
pattern
,
out1
)
==
{
x
:
x_pt
,
y
:
y_pt
,
a
:
1
}
assert
unify
(
pattern
,
out2
)
==
{
x
:
x_pt
,
y
:
y_pt
,
a
:
2
}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论