Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5d298a5
提交
d5d298a5
authored
9月 12, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup FusionOptimizer code
上级
71618f60
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
79 行增加
和
94 行删除
+79
-94
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+79
-94
没有找到文件。
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
d5d298a5
...
...
@@ -5,7 +5,7 @@ import sys
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Generator
,
Sequence
from
functools
import
cache
,
reduce
from
typing
import
TypeVar
from
typing
import
Literal
from
warnings
import
warn
import
pytensor.scalar.basic
as
ps
...
...
@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter):
callbacks_before
=
fgraph
.
execute_callbacks_times
.
copy
()
callback_before
=
fgraph
.
execute_callbacks_time
max_operands
=
elemwise_max_operands_fct
(
None
)
def
find_next_fuseable_subgraph
(
fg
:
FunctionGraph
,
)
->
Generator
[
tuple
[
list
[
Variable
],
list
[
Variable
]],
None
,
None
]:
...
...
@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter):
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
"""
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
def
initialize_fuseable_mappings
(
...
...
@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter):
# to ensure the rewrite remains deterministic.
# This is not a problem for unfuseable ones, as they can never
# become part of the graph.
fuseable_clients
:
FUSEABLE_MAPPING
=
defaultdict
(
lis
t
)
fuseable_clients
:
FUSEABLE_MAPPING
=
defaultdict
(
se
t
)
unfuseable_clients
:
UNFUSEABLE_MAPPING
=
defaultdict
(
set
)
for
out
,
clients
in
fg
.
clients
.
items
():
# Old FunctionGraph nodes remain in the clients dictionary
# even after they are removed by rewrites
if
not
clients
:
continue
out_maybe_fuseable
=
(
out
.
owner
out
.
owner
is
not
None
and
isinstance
(
out
.
owner
.
op
,
Elemwise
)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
and
len
(
out
.
owner
.
outputs
)
==
1
and
elemwise_scalar_op_has_c_code
(
out
.
owner
)
)
for
client
,
_
in
clients
:
if
(
out_maybe_fuseable
and
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
and
out
.
type
.
broadcastable
==
client
.
outputs
[
0
]
.
type
.
broadcastable
and
elemwise_scalar_op_has_c_code
(
client
)
):
if
client
not
in
fuseable_clients
[
out
]:
fuseable_clients
[
out
]
.
append
(
client
)
else
:
unfuseable_clients
[
out
]
.
add
(
client
)
if
out_maybe_fuseable
:
out_bcast
=
out
.
type
.
broadcastable
for
client
,
_
in
clients
:
if
(
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
and
out_bcast
==
client
.
outputs
[
0
]
.
type
.
broadcastable
and
elemwise_scalar_op_has_c_code
(
client
)
):
fuseable_clients
[
out
]
.
add
(
client
)
else
:
unfuseable_clients
[
out
]
.
add
(
client
)
else
:
unfuseable_clients
[
out
]
=
{
client
for
client
,
_
in
clients
}
return
fuseable_clients
,
unfuseable_clients
...
...
@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter):
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
toposort_index
:
dict
[
Apply
,
int
],
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
KT
=
TypeVar
(
"KT"
)
VT
=
TypeVar
(
"VT"
,
list
,
set
)
def
shallow_clone_defaultdict
(
d
:
defaultdict
[
KT
,
VT
],
)
->
defaultdict
[
KT
,
VT
]:
new_dict
:
defaultdict
[
KT
,
VT
]
=
defaultdict
(
d
.
default_factory
)
new_dict
.
update
({
k
:
v
.
copy
()
for
k
,
v
in
d
.
items
()})
return
new_dict
def
variables_depend_on
(
variables
,
depend_on
,
stop_search_at
=
None
)
->
bool
:
...
...
@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
.
add
(
starting_node
)
continue
subgraph_inputs
:
list
[
Variable
]
=
[]
subgraph_outputs
:
list
[
Variable
]
=
[]
subgraph_inputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_outputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
unfuseable_clients_subgraph
:
set
[
Variable
]
=
set
()
# Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp
=
shallow_clone_defaultdict
(
fuseable_clients
)
unfuseable_clients_clone
=
shallow_clone_defaultdict
(
unfuseable_clients
fuseable_clients_clone
:
FUSEABLE_MAPPING
=
defaultdict
(
set
)
fuseable_clients_clone
.
update
(
{
k
:
v
.
copy
()
for
k
,
v
in
fuseable_clients
.
items
()}
)
unfuseable_clients_clone
:
UNFUSEABLE_MAPPING
=
defaultdict
(
set
)
unfuseable_clients_clone
.
update
(
{
k
:
v
.
copy
()
for
k
,
v
in
unfuseable_clients
.
items
()}
)
fuseable_nodes_to_visit
=
deque
([
starting_node
])
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
...
...
@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter):
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
fuseable_nodes_to_visit
=
deque
([
starting_node
])
while
fuseable_nodes_to_visit
:
next_node
=
fuseable_nodes_to_visit
.
popleft
()
visited_nodes
.
add
(
next_node
)
...
...
@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter):
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output
=
(
next_out
not
in
fuseable_clients_temp
or
next_out
in
unfuseable_clients_clone
)
must_become_output
=
not
fuseable_clients_clone
.
get
(
next_out
)
or
unfuseable_clients_clone
.
get
(
next_out
)
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
if
must_become_output
and
next_out
in
subgraph_outputs
:
subgraph_outputs
.
remove
(
next_out
)
if
must_become_output
:
subgraph_outputs
.
pop
(
next_out
,
None
)
required_unfuseable_inputs
=
[
inp
...
...
@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter):
if
(
inp
.
owner
in
visited_nodes
# next_node could have the same input repeated
and
next_node
in
fuseable_clients_
temp
[
inp
]
and
next_node
in
fuseable_clients_
clone
[
inp
]
):
fuseable_clients_
temp
[
inp
]
.
remove
(
next_node
)
fuseable_clients_
clone
[
inp
]
.
remove
(
next_node
)
unfuseable_clients_clone
[
inp
]
.
add
(
next_node
)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
for
client
in
fuseable_clients_temp
[
next_out
]:
# need to convert to tuple not to change set size during iteration
for
client
in
tuple
(
fuseable_clients_clone
[
next_out
]):
if
client
in
visited_nodes
:
fuseable_clients_
temp
[
next_out
]
.
remove
(
client
)
fuseable_clients_
clone
[
next_out
]
.
remove
(
client
)
unfuseable_clients_clone
[
next_out
]
.
add
(
client
)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
...
...
@@ -771,74 +757,72 @@ class FusionOptimizer(GraphRewriter):
# mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
for
inp
in
new_required_unfuseable_inputs
:
if
inp
not
in
subgraph_inputs
:
subgraph_inputs
.
append
(
inp
)
subgraph_inputs
[
inp
]
=
None
if
must_become_output
:
subgraph_outputs
.
append
(
next_out
)
subgraph_outputs
[
next_out
]
=
None
unfuseable_clients_subgraph
.
update
(
new_implied_unfuseable_clients
)
# Expand through unvisited fuseable ancestors
f
or
inp
in
sorted
(
(
inp
for
inp
in
next_node
.
inputs
if
(
i
np
not
in
required_unfuseable_inputs
and
inp
.
owner
not
in
visited_node
s
)
),
key
=
lambda
inp
:
toposort_index
[
inp
.
owner
]
,
reverse
=
True
,
):
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
f
useable_nodes_to_visit
.
extendleft
(
sorted
(
(
inp
.
owner
for
inp
in
next_node
.
inputs
i
f
(
inp
not
in
required_unfuseable_input
s
and
inp
.
owner
not
in
visited_nodes
)
)
,
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
)
)
# Expand through unvisited fuseable clients
for
next_node
in
sorted
(
(
node
for
node
in
fuseable_clients_temp
.
get
(
next_out
,
())
if
node
not
in
visited_nodes
),
key
=
lambda
node
:
toposort_index
[
node
],
):
fuseable_nodes_to_visit
.
append
(
next_node
)
fuseable_nodes_to_visit
.
extend
(
sorted
(
(
node
for
node
in
fuseable_clients_clone
.
get
(
next_out
,
())
if
node
not
in
visited_nodes
),
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
)
)
# Don't return if final subgraph is just the original Elemwise
if
len
(
subgraph_outputs
)
==
1
and
set
(
subgraph_outputs
[
0
]
.
owner
.
inputs
next
(
iter
(
subgraph_outputs
))
.
owner
.
inputs
)
==
set
(
subgraph_inputs
):
# Update global fuseable mappings
# No input was actually fuseable
for
inp
in
starting_node
.
inputs
:
if
starting_node
in
fuseable_clients
.
get
(
inp
,
()):
fuseable_clients
[
inp
]
.
remove
(
starting_node
)
unfuseable_clients
[
inp
]
.
add
(
starting_node
)
fuseable_clients
[
inp
]
.
discard
(
starting_node
)
unfuseable_clients
[
inp
]
.
add
(
starting_node
)
# No client was actually fuseable
unfuseable_clients
[
starting_out
]
.
update
(
fuseable_clients
.
pop
(
starting_out
,
())
)
continue
return
subgraph_inputs
,
subgraph_outputs
return
list
(
subgraph_inputs
),
list
(
subgraph_outputs
)
raise
ValueError
def
update_fuseable_mappings_after_fg_replace
(
*
,
fg
:
FunctionGraph
,
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
starting_nodes
:
set
[
Apply
],
updated_nodes
:
set
[
Apply
],
)
->
None
:
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# original nodes
next_nodes
=
fg
.
apply_nodes
(
new_composite_node
,)
=
next_nodes
-
starting_nodes
dropped_nodes
=
starting_nodes
-
next_nodes
(
new_composite_node
,)
=
updated_nodes
-
starting_nodes
dropped_nodes
=
starting_nodes
-
updated_nodes
# Remove intermediate Composite nodes from mappings
for
dropped_node
in
dropped_nodes
:
...
...
@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter):
# Update fuseable information for subgraph inputs
for
inp
in
subgraph_inputs
:
if
inp
in
fuseable_clients
:
new_fuseable_clients
=
[
new_fuseable_clients
=
{
client
for
client
in
fuseable_clients
[
inp
]
if
client
not
in
dropped_nodes
]
}
if
new_fuseable_clients
:
fuseable_clients
[
inp
]
=
new_fuseable_clients
else
:
...
...
@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter):
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace
(
fg
=
fg
,
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
starting_nodes
=
starting_nodes
,
updated_nodes
=
fg
.
apply_nodes
,
)
max_operands
=
elemwise_max_operands_fct
(
None
)
reason
=
self
.
__class__
.
__name__
nb_fused
=
0
nb_replacement
=
0
for
inputs
,
outputs
in
find_next_fuseable_subgraph
(
fgraph
):
...
...
@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter):
assert
len
(
outputs
)
==
len
(
composite_outputs
)
for
old_out
,
composite_out
in
zip
(
outputs
,
composite_outputs
):
# Preserve any names on the original outputs
if
old_out
.
name
:
composite_out
.
name
=
old_
out
.
name
if
old_
name
:
=
old_
out
.
name
:
composite_out
.
name
=
old_name
starting_nodes
=
len
(
fgraph
.
apply_nodes
)
fgraph
.
replace_all_validate
(
list
(
zip
(
outputs
,
composite_outputs
,
strict
=
True
)),
reason
=
self
.
__class__
.
__name__
,
tuple
(
zip
(
outputs
,
composite_outputs
)),
reason
=
reason
)
nb_fused
+=
1
nb_replacement
+=
(
starting_nodes
-
len
(
fgraph
.
apply_nodes
))
+
1
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论