Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
566af64d
提交
566af64d
authored
9月 12, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use bitset to check ancestors more efficiently
上级
dc1e3b9c
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
69 行增加
和
68 行删除
+69
-68
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+62
-61
test_printing.py
tests/test_printing.py
+7
-7
没有找到文件。
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
566af64d
...
...
@@ -6,6 +6,7 @@ import typing
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Generator
,
Sequence
from
functools
import
cache
,
reduce
from
operator
import
or_
from
typing
import
Literal
from
warnings
import
warn
...
...
@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import (
)
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.traversal
import
ancestors
,
toposort
from
pytensor.graph.traversal
import
toposort
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.scalar.math
import
Grad2F1Loop
,
_grad_2f1_loop
from
pytensor.tensor.basic
import
(
...
...
@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
ancestors_bitset
:
dict
[
Apply
,
int
],
toposort_index
:
dict
[
Apply
,
int
],
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
def
variables_depend_on
(
variables
,
depend_on
,
stop_search_at
=
None
)
->
bool
:
return
any
(
a
in
depend_on
for
a
in
ancestors
(
variables
,
blockers
=
stop_search_at
)
)
for
starting_node
in
toposort_index
:
if
starting_node
in
visited_nodes
:
continue
...
...
@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter):
subgraph_inputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_outputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
unfuseable_clients_subgraph
:
set
[
Variable
]
=
set
()
subgraph_inputs_ancestors_bitset
=
0
unfuseable_clients_subgraph_bitset
=
0
# If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones
...
...
@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter):
if
must_become_output
:
subgraph_outputs
.
pop
(
next_out
,
None
)
required_unfuseable_inputs
=
[
inp
for
inp
in
next_node
.
inputs
if
next_node
in
unfuseable_clients_clone
.
get
(
inp
)
]
new_required_unfuseable_inputs
=
[
inp
for
inp
in
required_unfuseable_inputs
if
inp
not
in
subgraph_inputs
]
must_backtrack
=
False
if
new_required_unfuseable_inputs
and
subgraph_outputs
:
# We need to check that any new inputs required by this node
# We need to check that any inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if
variables_depend_on
(
[
next_out
],
depend_on
=
unfuseable_clients_subgraph
,
stop_search_at
=
subgraph_outputs
,
):
must_backtrack
=
True
must_backtrack
=
(
ancestors_bitset
[
next_node
]
&
unfuseable_clients_subgraph_bitset
)
if
not
must_backtrack
:
implied_unfuseable_clients
=
{
c
implied_unfuseable_clients_bitset
=
reduce
(
or_
,
(
1
<<
toposort_index
[
client
]
for
client
in
unfuseable_clients_clone
.
get
(
next_out
)
if
not
isinstance
(
client
.
op
,
Output
)
for
c
in
client
.
outputs
}
new_implied_unfuseable_clients
=
(
implied_unfuseable_clients
-
unfuseable_clients_subgraph
),
0
,
)
if
new_implied_unfuseable_clients
and
subgraph_inputs
:
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
if
variables_depend_on
(
subgraph_inputs
,
depend_on
=
new_implied_unfuseable_clients
,
):
must_backtrack
=
True
must_backtrack
=
(
subgraph_inputs_ancestors_bitset
&
implied_unfuseable_clients_bitset
)
if
must_backtrack
:
for
inp
in
next_node
.
inputs
:
...
...
@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter):
# immediate dependency problems. Update subgraph
# mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
for
inp
in
new_required_unfuseable_inputs
:
subgraph_inputs
[
inp
]
=
None
if
must_become_output
:
subgraph_outputs
[
next_out
]
=
None
unfuseable_clients_subgraph
.
update
(
new_implied_unfuseable_clients
unfuseable_clients_subgraph
_bitset
|=
(
implied_unfuseable_clients_bitset
)
# Expand through unvisited fuseable ancestors
fuseable_nodes_to_visit
.
extendleft
(
sorted
(
(
inp
.
owner
for
inp
in
next_node
.
inputs
if
(
inp
not
in
required_unfuseable_inputs
and
inp
.
owner
not
in
visited_nodes
)
),
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
)
for
inp
in
sorted
(
next_node
.
inputs
,
key
=
lambda
x
:
toposort_index
.
get
(
x
.
owner
,
-
1
),
):
if
next_node
in
unfuseable_clients_clone
.
get
(
inp
,
()):
# input must become an input of the subgraph since it's unfuseable with new node
subgraph_inputs_ancestors_bitset
|=
(
ancestors_bitset
.
get
(
inp
.
owner
,
0
)
)
subgraph_inputs
[
inp
]
=
None
elif
inp
.
owner
not
in
visited_nodes
:
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
# Expand through unvisited fuseable clients
fuseable_nodes_to_visit
.
extend
(
...
...
@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
toposort_index
:
dict
[
Apply
,
int
],
ancestors_bitset
:
dict
[
Apply
,
int
],
starting_nodes
:
set
[
Apply
],
updated_nodes
:
set
[
Apply
],
)
->
None
:
...
...
@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter):
dropped_nodes
=
starting_nodes
-
updated_nodes
# Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset
=
0
new_node_toposort_index
=
len
(
toposort_index
)
for
dropped_node
in
dropped_nodes
:
(
dropped_out
,)
=
dropped_node
.
outputs
fuseable_clients
.
pop
(
dropped_out
,
None
)
unfuseable_clients
.
pop
(
dropped_out
,
None
)
visited_nodes
.
remove
(
dropped_node
)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset
|=
ancestors_bitset
[
dropped_node
]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index
=
max
(
new_node_toposort_index
,
toposort_index
[
dropped_node
]
)
ancestors_bitset
[
new_composite_node
]
=
new_node_ancestor_bitset
toposort_index
[
new_composite_node
]
=
new_node_toposort_index
# Update fuseable information for subgraph inputs
for
inp
in
subgraph_inputs
:
...
...
@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter):
fuseable_clients
,
unfuseable_clients
=
initialize_fuseable_mappings
(
fg
=
fg
)
visited_nodes
:
set
[
Apply
]
=
set
()
toposort_index
=
{
node
:
i
for
i
,
node
in
enumerate
(
fgraph
.
toposort
())}
# For each node, create a bitset of all its ancestors
# This allows us to quickly check if a variable depends on a set of nodes
ancestors_bitset
:
dict
[
Apply
,
int
]
=
{}
for
node
,
index
in
toposort_index
.
items
():
node_ancestor_bitset
=
1
<<
index
for
inp
in
node
.
inputs
:
if
(
inp_node
:
=
inp
.
owner
)
is
not
None
:
node_ancestor_bitset
|=
ancestors_bitset
[
inp_node
]
ancestors_bitset
[
node
]
=
node_ancestor_bitset
while
True
:
try
:
subgraph_inputs
,
subgraph_outputs
=
find_fuseable_subgraph
(
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
ancestors_bitset
=
ancestors_bitset
,
toposort_index
=
toposort_index
,
)
except
ValueError
:
...
...
@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
toposort_index
=
toposort_index
,
ancestors_bitset
=
ancestors_bitset
,
starting_nodes
=
starting_nodes
,
updated_nodes
=
fg
.
apply_nodes
,
)
...
...
tests/test_printing.py
浏览文件 @
566af64d
...
...
@@ -301,7 +301,8 @@ def test_debugprint():
Gemv_op_name
=
"CGemv"
if
pytensor
.
config
.
blas__ldflags
else
"Gemv"
exp_res
=
dedent
(
r"""
Composite{(i2 + (i0 - i1))} 4
Composite{(i0 + (i1 - i2))} 4
├─ A
├─ ExpandDims{axis=0} v={0: [0]} 3
"""
f
" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
...
...
@@ -313,17 +314,16 @@ def test_debugprint():
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0
├─ D
└─ A
└─ D
Inner graphs:
Composite{(i
2 + (i0 - i1
))}
Composite{(i
0 + (i1 - i2
))}
← add 'o0'
├─ i2
└─ sub
├─ i0
└─ i1
└─ sub
├─ i1
└─ i2
"""
)
.
lstrip
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论