Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
566af64d
提交
566af64d
authored
9月 12, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use bitset to check ancestors more efficiently
上级
dc1e3b9c
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
77 行增加
和
76 行删除
+77
-76
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+70
-69
test_printing.py
tests/test_printing.py
+7
-7
没有找到文件。
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
566af64d
...
@@ -6,6 +6,7 @@ import typing
...
@@ -6,6 +6,7 @@ import typing
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Generator
,
Sequence
from
collections.abc
import
Generator
,
Sequence
from
functools
import
cache
,
reduce
from
functools
import
cache
,
reduce
from
operator
import
or_
from
typing
import
Literal
from
typing
import
Literal
from
warnings
import
warn
from
warnings
import
warn
...
@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import (
...
@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import (
)
)
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.traversal
import
ancestors
,
toposort
from
pytensor.graph.traversal
import
toposort
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.scalar.math
import
Grad2F1Loop
,
_grad_2f1_loop
from
pytensor.scalar.math
import
Grad2F1Loop
,
_grad_2f1_loop
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
...
@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter):
...
@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
:
set
[
Apply
],
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
ancestors_bitset
:
dict
[
Apply
,
int
],
toposort_index
:
dict
[
Apply
,
int
],
toposort_index
:
dict
[
Apply
,
int
],
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
def
variables_depend_on
(
variables
,
depend_on
,
stop_search_at
=
None
)
->
bool
:
return
any
(
a
in
depend_on
for
a
in
ancestors
(
variables
,
blockers
=
stop_search_at
)
)
for
starting_node
in
toposort_index
:
for
starting_node
in
toposort_index
:
if
starting_node
in
visited_nodes
:
if
starting_node
in
visited_nodes
:
continue
continue
...
@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter):
...
@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter):
subgraph_inputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_inputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_outputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_outputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
unfuseable_clients_subgraph
:
set
[
Variable
]
=
set
()
subgraph_inputs_ancestors_bitset
=
0
unfuseable_clients_subgraph_bitset
=
0
# If we need to manipulate the maps in place, we'll do a shallow copy later
# If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones
# For now we query on the original ones
...
@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter):
...
@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter):
if
must_become_output
:
if
must_become_output
:
subgraph_outputs
.
pop
(
next_out
,
None
)
subgraph_outputs
.
pop
(
next_out
,
None
)
required_unfuseable_inputs
=
[
# We need to check that any inputs required by this node
inp
# do not depend on other outputs of the current subgraph,
for
inp
in
next_node
.
inputs
# via an unfuseable path.
if
next_node
in
unfuseable_clients_clone
.
get
(
inp
)
must_backtrack
=
(
]
ancestors_bitset
[
next_node
]
new_required_unfuseable_inputs
=
[
&
unfuseable_clients_subgraph_bitset
inp
)
for
inp
in
required_unfuseable_inputs
if
inp
not
in
subgraph_inputs
]
must_backtrack
=
False
if
new_required_unfuseable_inputs
and
subgraph_outputs
:
# We need to check that any new inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if
variables_depend_on
(
[
next_out
],
depend_on
=
unfuseable_clients_subgraph
,
stop_search_at
=
subgraph_outputs
,
):
must_backtrack
=
True
if
not
must_backtrack
:
if
not
must_backtrack
:
implied_unfuseable_clients
=
{
implied_unfuseable_clients_bitset
=
reduce
(
c
or_
,
for
client
in
unfuseable_clients_clone
.
get
(
next_out
)
(
if
not
isinstance
(
client
.
op
,
Output
)
1
<<
toposort_index
[
client
]
for
c
in
client
.
outputs
for
client
in
unfuseable_clients_clone
.
get
(
next_out
)
}
if
not
isinstance
(
client
.
op
,
Output
)
),
new_implied_unfuseable_clients
=
(
0
,
implied_unfuseable_clients
-
unfuseable_clients_subgraph
)
)
if
new_implied_unfuseable_clients
and
subgraph_inputs
:
# We need to check that any inputs of the current subgraph
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# do not depend on other clients of this node,
# via an unfuseable path.
# via an unfuseable path.
must_backtrack
=
(
if
variables_depend_on
(
subgraph_inputs_ancestors_bitset
subgraph_inputs
,
&
implied_unfuseable_clients_bitset
depend_on
=
new_implied_unfuseable_clients
,
)
):
must_backtrack
=
True
if
must_backtrack
:
if
must_backtrack
:
for
inp
in
next_node
.
inputs
:
for
inp
in
next_node
.
inputs
:
...
@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter):
...
@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter):
# immediate dependency problems. Update subgraph
# immediate dependency problems. Update subgraph
# mappings as if it next_node was part of it.
# mappings as if it next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
# Useless inputs will be removed by the useless Composite rewrite
for
inp
in
new_required_unfuseable_inputs
:
subgraph_inputs
[
inp
]
=
None
if
must_become_output
:
if
must_become_output
:
subgraph_outputs
[
next_out
]
=
None
subgraph_outputs
[
next_out
]
=
None
unfuseable_clients_subgraph
.
update
(
unfuseable_clients_subgraph
_bitset
|=
(
new_implied_unfuseable_clients
implied_unfuseable_clients_bitset
)
)
# Expand through unvisited fuseable ancestors
for
inp
in
sorted
(
fuseable_nodes_to_visit
.
extendleft
(
next_node
.
inputs
,
sorted
(
key
=
lambda
x
:
toposort_index
.
get
(
x
.
owner
,
-
1
),
(
):
inp
.
owner
if
next_node
in
unfuseable_clients_clone
.
get
(
inp
,
()):
for
inp
in
next_node
.
inputs
# input must become an input of the subgraph since it's unfuseable with new node
if
(
subgraph_inputs_ancestors_bitset
|=
(
inp
not
in
required_unfuseable_inputs
ancestors_bitset
.
get
(
inp
.
owner
,
0
)
and
inp
.
owner
not
in
visited_nodes
)
)
subgraph_inputs
[
inp
]
=
None
),
elif
inp
.
owner
not
in
visited_nodes
:
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
)
)
# Expand through unvisited fuseable clients
# Expand through unvisited fuseable clients
fuseable_nodes_to_visit
.
extend
(
fuseable_nodes_to_visit
.
extend
(
...
@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter):
...
@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
:
set
[
Apply
],
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
toposort_index
:
dict
[
Apply
,
int
],
ancestors_bitset
:
dict
[
Apply
,
int
],
starting_nodes
:
set
[
Apply
],
starting_nodes
:
set
[
Apply
],
updated_nodes
:
set
[
Apply
],
updated_nodes
:
set
[
Apply
],
)
->
None
:
)
->
None
:
...
@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter):
...
@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter):
dropped_nodes
=
starting_nodes
-
updated_nodes
dropped_nodes
=
starting_nodes
-
updated_nodes
# Remove intermediate Composite nodes from mappings
# Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset
=
0
new_node_toposort_index
=
len
(
toposort_index
)
for
dropped_node
in
dropped_nodes
:
for
dropped_node
in
dropped_nodes
:
(
dropped_out
,)
=
dropped_node
.
outputs
(
dropped_out
,)
=
dropped_node
.
outputs
fuseable_clients
.
pop
(
dropped_out
,
None
)
fuseable_clients
.
pop
(
dropped_out
,
None
)
unfuseable_clients
.
pop
(
dropped_out
,
None
)
unfuseable_clients
.
pop
(
dropped_out
,
None
)
visited_nodes
.
remove
(
dropped_node
)
visited_nodes
.
remove
(
dropped_node
)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset
|=
ancestors_bitset
[
dropped_node
]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index
=
max
(
new_node_toposort_index
,
toposort_index
[
dropped_node
]
)
ancestors_bitset
[
new_composite_node
]
=
new_node_ancestor_bitset
toposort_index
[
new_composite_node
]
=
new_node_toposort_index
# Update fuseable information for subgraph inputs
# Update fuseable information for subgraph inputs
for
inp
in
subgraph_inputs
:
for
inp
in
subgraph_inputs
:
...
@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter):
...
@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter):
fuseable_clients
,
unfuseable_clients
=
initialize_fuseable_mappings
(
fg
=
fg
)
fuseable_clients
,
unfuseable_clients
=
initialize_fuseable_mappings
(
fg
=
fg
)
visited_nodes
:
set
[
Apply
]
=
set
()
visited_nodes
:
set
[
Apply
]
=
set
()
toposort_index
=
{
node
:
i
for
i
,
node
in
enumerate
(
fgraph
.
toposort
())}
toposort_index
=
{
node
:
i
for
i
,
node
in
enumerate
(
fgraph
.
toposort
())}
# Create a bitset for each node of all its ancestors
# This allows to quickly check if a variable depends on a set
ancestors_bitset
:
dict
[
Apply
,
int
]
=
{}
for
node
,
index
in
toposort_index
.
items
():
node_ancestor_bitset
=
1
<<
index
for
inp
in
node
.
inputs
:
if
(
inp_node
:
=
inp
.
owner
)
is
not
None
:
node_ancestor_bitset
|=
ancestors_bitset
[
inp_node
]
ancestors_bitset
[
node
]
=
node_ancestor_bitset
while
True
:
while
True
:
try
:
try
:
subgraph_inputs
,
subgraph_outputs
=
find_fuseable_subgraph
(
subgraph_inputs
,
subgraph_outputs
=
find_fuseable_subgraph
(
visited_nodes
=
visited_nodes
,
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
ancestors_bitset
=
ancestors_bitset
,
toposort_index
=
toposort_index
,
toposort_index
=
toposort_index
,
)
)
except
ValueError
:
except
ValueError
:
...
@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter):
...
@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
=
visited_nodes
,
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
toposort_index
=
toposort_index
,
ancestors_bitset
=
ancestors_bitset
,
starting_nodes
=
starting_nodes
,
starting_nodes
=
starting_nodes
,
updated_nodes
=
fg
.
apply_nodes
,
updated_nodes
=
fg
.
apply_nodes
,
)
)
...
...
tests/test_printing.py
浏览文件 @
566af64d
...
@@ -301,7 +301,8 @@ def test_debugprint():
...
@@ -301,7 +301,8 @@ def test_debugprint():
Gemv_op_name
=
"CGemv"
if
pytensor
.
config
.
blas__ldflags
else
"Gemv"
Gemv_op_name
=
"CGemv"
if
pytensor
.
config
.
blas__ldflags
else
"Gemv"
exp_res
=
dedent
(
exp_res
=
dedent
(
r"""
r"""
Composite{(i2 + (i0 - i1))} 4
Composite{(i0 + (i1 - i2))} 4
├─ A
├─ ExpandDims{axis=0} v={0: [0]} 3
├─ ExpandDims{axis=0} v={0: [0]} 3
"""
"""
f
" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
f
" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
...
@@ -313,17 +314,16 @@ def test_debugprint():
...
@@ -313,17 +314,16 @@ def test_debugprint():
│ ├─ B
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0
│ └─ 0.0
├─ D
└─ D
└─ A
Inner graphs:
Inner graphs:
Composite{(i
2 + (i0 - i1
))}
Composite{(i
0 + (i1 - i2
))}
← add 'o0'
← add 'o0'
├─ i2
└─ sub
├─ i0
├─ i0
└─ i1
└─ sub
├─ i1
└─ i2
"""
"""
)
.
lstrip
()
)
.
lstrip
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论