Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5d298a5
提交
d5d298a5
authored
9月 12, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup FusionOptimizer code
上级
71618f60
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
59 行增加
和
74 行删除
+59
-74
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+59
-74
没有找到文件。
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
d5d298a5
...
@@ -5,7 +5,7 @@ import sys
...
@@ -5,7 +5,7 @@ import sys
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Generator
,
Sequence
from
collections.abc
import
Generator
,
Sequence
from
functools
import
cache
,
reduce
from
functools
import
cache
,
reduce
from
typing
import
TypeVar
from
typing
import
Literal
from
warnings
import
warn
from
warnings
import
warn
import
pytensor.scalar.basic
as
ps
import
pytensor.scalar.basic
as
ps
...
@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter):
...
@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter):
callbacks_before
=
fgraph
.
execute_callbacks_times
.
copy
()
callbacks_before
=
fgraph
.
execute_callbacks_times
.
copy
()
callback_before
=
fgraph
.
execute_callbacks_time
callback_before
=
fgraph
.
execute_callbacks_time
max_operands
=
elemwise_max_operands_fct
(
None
)
def
find_next_fuseable_subgraph
(
def
find_next_fuseable_subgraph
(
fg
:
FunctionGraph
,
fg
:
FunctionGraph
,
)
->
Generator
[
tuple
[
list
[
Variable
],
list
[
Variable
]],
None
,
None
]:
)
->
Generator
[
tuple
[
list
[
Variable
],
list
[
Variable
]],
None
,
None
]:
...
@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter):
This generator assumes that such subgraph is replaced by a single
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
Elemwise Composite before being accessed again in the next iteration.
"""
"""
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
def
initialize_fuseable_mappings
(
def
initialize_fuseable_mappings
(
...
@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter):
...
@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter):
# to ensure the rewrite remains deterministic.
# to ensure the rewrite remains deterministic.
# This is not a problem from unfuseable ones, as they can never
# This is not a problem from unfuseable ones, as they can never
# become part of the graph.
# become part of the graph.
fuseable_clients
:
FUSEABLE_MAPPING
=
defaultdict
(
lis
t
)
fuseable_clients
:
FUSEABLE_MAPPING
=
defaultdict
(
se
t
)
unfuseable_clients
:
UNFUSEABLE_MAPPING
=
defaultdict
(
set
)
unfuseable_clients
:
UNFUSEABLE_MAPPING
=
defaultdict
(
set
)
for
out
,
clients
in
fg
.
clients
.
items
():
for
out
,
clients
in
fg
.
clients
.
items
():
# Old FunctionGraph nodes remain in the clients dictionary
# even after they are removed by rewrites
if
not
clients
:
continue
out_maybe_fuseable
=
(
out_maybe_fuseable
=
(
out
.
owner
out
.
owner
is
not
None
and
isinstance
(
out
.
owner
.
op
,
Elemwise
)
and
isinstance
(
out
.
owner
.
op
,
Elemwise
)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
and
len
(
out
.
owner
.
outputs
)
==
1
and
len
(
out
.
owner
.
outputs
)
==
1
and
elemwise_scalar_op_has_c_code
(
out
.
owner
)
and
elemwise_scalar_op_has_c_code
(
out
.
owner
)
)
)
if
out_maybe_fuseable
:
out_bcast
=
out
.
type
.
broadcastable
for
client
,
_
in
clients
:
for
client
,
_
in
clients
:
if
(
if
(
out_maybe_fuseable
isinstance
(
client
.
op
,
Elemwise
)
and
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
and
len
(
client
.
outputs
)
==
1
and
out
.
type
.
broadcastable
and
out_bcast
==
client
.
outputs
[
0
]
.
type
.
broadcastable
==
client
.
outputs
[
0
]
.
type
.
broadcastable
and
elemwise_scalar_op_has_c_code
(
client
)
and
elemwise_scalar_op_has_c_code
(
client
)
):
):
if
client
not
in
fuseable_clients
[
out
]:
fuseable_clients
[
out
]
.
add
(
client
)
fuseable_clients
[
out
]
.
append
(
client
)
else
:
else
:
unfuseable_clients
[
out
]
.
add
(
client
)
unfuseable_clients
[
out
]
.
add
(
client
)
else
:
unfuseable_clients
[
out
]
=
{
client
for
client
,
_
in
clients
}
return
fuseable_clients
,
unfuseable_clients
return
fuseable_clients
,
unfuseable_clients
...
@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter):
...
@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter):
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
toposort_index
:
dict
[
Apply
,
int
],
toposort_index
:
dict
[
Apply
,
int
],
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
)
->
tuple
[
list
[
Variable
],
list
[
Variable
]]:
KT
=
TypeVar
(
"KT"
)
VT
=
TypeVar
(
"VT"
,
list
,
set
)
def
shallow_clone_defaultdict
(
d
:
defaultdict
[
KT
,
VT
],
)
->
defaultdict
[
KT
,
VT
]:
new_dict
:
defaultdict
[
KT
,
VT
]
=
defaultdict
(
d
.
default_factory
)
new_dict
.
update
({
k
:
v
.
copy
()
for
k
,
v
in
d
.
items
()})
return
new_dict
def
variables_depend_on
(
def
variables_depend_on
(
variables
,
depend_on
,
stop_search_at
=
None
variables
,
depend_on
,
stop_search_at
=
None
)
->
bool
:
)
->
bool
:
...
@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter):
...
@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter):
visited_nodes
.
add
(
starting_node
)
visited_nodes
.
add
(
starting_node
)
continue
continue
subgraph_inputs
:
list
[
Variable
]
=
[]
subgraph_inputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
subgraph_outputs
:
list
[
Variable
]
=
[]
subgraph_outputs
:
dict
[
Variable
,
Literal
[
None
]]
=
{}
# ordered set
unfuseable_clients_subgraph
:
set
[
Variable
]
=
set
()
unfuseable_clients_subgraph
:
set
[
Variable
]
=
set
()
# Shallow cloning of maps so that they can be manipulated in place
# Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp
=
shallow_clone_defaultdict
(
fuseable_clients
)
fuseable_clients_clone
:
FUSEABLE_MAPPING
=
defaultdict
(
set
)
unfuseable_clients_clone
=
shallow_clone_defaultdict
(
fuseable_clients_clone
.
update
(
unfuseable_clients
{
k
:
v
.
copy
()
for
k
,
v
in
fuseable_clients
.
items
()}
)
unfuseable_clients_clone
:
UNFUSEABLE_MAPPING
=
defaultdict
(
set
)
unfuseable_clients_clone
.
update
(
{
k
:
v
.
copy
()
for
k
,
v
in
unfuseable_clients
.
items
()}
)
)
fuseable_nodes_to_visit
=
deque
([
starting_node
])
# We now try to expand as much as possible towards the potentially
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
# fuseable clients and ancestors to detect the largest possible
...
@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter):
# some inputs or clients may depend on other nodes of the same
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
# (unfuseable)
fuseable_nodes_to_visit
=
deque
([
starting_node
])
while
fuseable_nodes_to_visit
:
while
fuseable_nodes_to_visit
:
next_node
=
fuseable_nodes_to_visit
.
popleft
()
next_node
=
fuseable_nodes_to_visit
.
popleft
()
visited_nodes
.
add
(
next_node
)
visited_nodes
.
add
(
next_node
)
...
@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter):
...
@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter):
# If the output variable of next_node has no fuseable clients
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
# if it is to be fused.
must_become_output
=
(
must_become_output
=
not
fuseable_clients_clone
.
get
(
next_out
not
in
fuseable_clients_temp
next_out
or
next_out
in
unfuseable_clients_clone
)
or
unfuseable_clients_clone
.
get
(
next_out
)
)
# We have backtracked to this node, and it may no longer be a viable output,
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
# so we remove it and check again as if we had never seen this node
if
must_become_output
and
next_out
in
subgraph_outputs
:
if
must_become_output
:
subgraph_outputs
.
remove
(
next_out
)
subgraph_outputs
.
pop
(
next_out
,
None
)
required_unfuseable_inputs
=
[
required_unfuseable_inputs
=
[
inp
inp
...
@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter):
...
@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter):
if
(
if
(
inp
.
owner
in
visited_nodes
inp
.
owner
in
visited_nodes
# next_node could have the same input repeated
# next_node could have the same input repeated
and
next_node
in
fuseable_clients_
temp
[
inp
]
and
next_node
in
fuseable_clients_
clone
[
inp
]
):
):
fuseable_clients_
temp
[
inp
]
.
remove
(
next_node
)
fuseable_clients_
clone
[
inp
]
.
remove
(
next_node
)
unfuseable_clients_clone
[
inp
]
.
add
(
next_node
)
unfuseable_clients_clone
[
inp
]
.
add
(
next_node
)
# This input must become an output of the subgraph,
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
for
client
in
fuseable_clients_temp
[
next_out
]:
# need to convert to tuple not to change set size during iteration
for
client
in
tuple
(
fuseable_clients_clone
[
next_out
]):
if
client
in
visited_nodes
:
if
client
in
visited_nodes
:
fuseable_clients_
temp
[
next_out
]
.
remove
(
client
)
fuseable_clients_
clone
[
next_out
]
.
remove
(
client
)
unfuseable_clients_clone
[
next_out
]
.
add
(
client
)
unfuseable_clients_clone
[
next_out
]
.
add
(
client
)
# next_out must become an input of the subgraph.
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# We will revisit any of its clients currently
...
@@ -771,50 +757,49 @@ class FusionOptimizer(GraphRewriter):
...
@@ -771,50 +757,49 @@ class FusionOptimizer(GraphRewriter):
# mappings as if it next_node was part of it.
# mappings as if it next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
# Useless inputs will be removed by the useless Composite rewrite
for
inp
in
new_required_unfuseable_inputs
:
for
inp
in
new_required_unfuseable_inputs
:
if
inp
not
in
subgraph_inputs
:
subgraph_inputs
[
inp
]
=
None
subgraph_inputs
.
append
(
inp
)
if
must_become_output
:
if
must_become_output
:
subgraph_outputs
.
append
(
next_out
)
subgraph_outputs
[
next_out
]
=
None
unfuseable_clients_subgraph
.
update
(
unfuseable_clients_subgraph
.
update
(
new_implied_unfuseable_clients
new_implied_unfuseable_clients
)
)
# Expand through unvisited fuseable ancestors
# Expand through unvisited fuseable ancestors
for
inp
in
sorted
(
fuseable_nodes_to_visit
.
extendleft
(
sorted
(
(
(
inp
inp
.
owner
for
inp
in
next_node
.
inputs
for
inp
in
next_node
.
inputs
if
(
if
(
inp
not
in
required_unfuseable_inputs
inp
not
in
required_unfuseable_inputs
and
inp
.
owner
not
in
visited_nodes
and
inp
.
owner
not
in
visited_nodes
)
)
),
),
key
=
lambda
inp
:
toposort_index
[
inp
.
owner
],
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
reverse
=
True
,
)
):
)
fuseable_nodes_to_visit
.
appendleft
(
inp
.
owner
)
# Expand through unvisited fuseable clients
# Expand through unvisited fuseable clients
for
next_node
in
sorted
(
fuseable_nodes_to_visit
.
extend
(
sorted
(
(
(
node
node
for
node
in
fuseable_clients_temp
.
get
(
next_out
,
())
for
node
in
fuseable_clients_clone
.
get
(
next_out
,
())
if
node
not
in
visited_nodes
if
node
not
in
visited_nodes
),
),
key
=
lambda
node
:
toposort_index
[
node
],
key
=
toposort_index
.
get
,
# type: ignore[arg-type]
):
)
fuseable_nodes_to_visit
.
append
(
next_node
)
)
# Don't return if final subgraph is just the original Elemwise
# Don't return if final subgraph is just the original Elemwise
if
len
(
subgraph_outputs
)
==
1
and
set
(
if
len
(
subgraph_outputs
)
==
1
and
set
(
subgraph_outputs
[
0
]
.
owner
.
inputs
next
(
iter
(
subgraph_outputs
))
.
owner
.
inputs
)
==
set
(
subgraph_inputs
):
)
==
set
(
subgraph_inputs
):
# Update global fuseable mappings
# Update global fuseable mappings
# No input was actually fuseable
# No input was actually fuseable
for
inp
in
starting_node
.
inputs
:
for
inp
in
starting_node
.
inputs
:
if
starting_node
in
fuseable_clients
.
get
(
inp
,
()):
fuseable_clients
[
inp
]
.
discard
(
starting_node
)
fuseable_clients
[
inp
]
.
remove
(
starting_node
)
unfuseable_clients
[
inp
]
.
add
(
starting_node
)
unfuseable_clients
[
inp
]
.
add
(
starting_node
)
# No client was actually fuseable
# No client was actually fuseable
unfuseable_clients
[
starting_out
]
.
update
(
unfuseable_clients
[
starting_out
]
.
update
(
...
@@ -822,23 +807,22 @@ class FusionOptimizer(GraphRewriter):
...
@@ -822,23 +807,22 @@ class FusionOptimizer(GraphRewriter):
)
)
continue
continue
return
subgraph_inputs
,
subgraph_outputs
return
list
(
subgraph_inputs
),
list
(
subgraph_outputs
)
raise
ValueError
raise
ValueError
def
update_fuseable_mappings_after_fg_replace
(
def
update_fuseable_mappings_after_fg_replace
(
*
,
*
,
fg
:
FunctionGraph
,
visited_nodes
:
set
[
Apply
],
visited_nodes
:
set
[
Apply
],
fuseable_clients
:
FUSEABLE_MAPPING
,
fuseable_clients
:
FUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
unfuseable_clients
:
UNFUSEABLE_MAPPING
,
starting_nodes
:
set
[
Apply
],
starting_nodes
:
set
[
Apply
],
updated_nodes
:
set
[
Apply
],
)
->
None
:
)
->
None
:
# Find new composite node and dropped intermediate nodes
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# by comparing the current fg.apply nodes with the cached
# original nodes
# original nodes
next_nodes
=
fg
.
apply_nodes
(
new_composite_node
,)
=
updated_nodes
-
starting_nodes
(
new_composite_node
,)
=
next_nodes
-
starting_nodes
dropped_nodes
=
starting_nodes
-
updated_nodes
dropped_nodes
=
starting_nodes
-
next_nodes
# Remove intermediate Composite nodes from mappings
# Remove intermediate Composite nodes from mappings
for
dropped_node
in
dropped_nodes
:
for
dropped_node
in
dropped_nodes
:
...
@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter):
...
@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter):
# Update fuseable information for subgraph inputs
# Update fuseable information for subgraph inputs
for
inp
in
subgraph_inputs
:
for
inp
in
subgraph_inputs
:
if
inp
in
fuseable_clients
:
if
inp
in
fuseable_clients
:
new_fuseable_clients
=
[
new_fuseable_clients
=
{
client
client
for
client
in
fuseable_clients
[
inp
]
for
client
in
fuseable_clients
[
inp
]
if
client
not
in
dropped_nodes
if
client
not
in
dropped_nodes
]
}
if
new_fuseable_clients
:
if
new_fuseable_clients
:
fuseable_clients
[
inp
]
=
new_fuseable_clients
fuseable_clients
[
inp
]
=
new_fuseable_clients
else
:
else
:
...
@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter):
...
@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter):
# generator. For large models (as in `TestFusion.test_big_fusion`)
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace
(
update_fuseable_mappings_after_fg_replace
(
fg
=
fg
,
visited_nodes
=
visited_nodes
,
visited_nodes
=
visited_nodes
,
fuseable_clients
=
fuseable_clients
,
fuseable_clients
=
fuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
unfuseable_clients
=
unfuseable_clients
,
starting_nodes
=
starting_nodes
,
starting_nodes
=
starting_nodes
,
updated_nodes
=
fg
.
apply_nodes
,
)
)
max_operands
=
elemwise_max_operands_fct
(
None
)
reason
=
self
.
__class__
.
__name__
nb_fused
=
0
nb_fused
=
0
nb_replacement
=
0
nb_replacement
=
0
for
inputs
,
outputs
in
find_next_fuseable_subgraph
(
fgraph
):
for
inputs
,
outputs
in
find_next_fuseable_subgraph
(
fgraph
):
...
@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter):
...
@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter):
assert
len
(
outputs
)
==
len
(
composite_outputs
)
assert
len
(
outputs
)
==
len
(
composite_outputs
)
for
old_out
,
composite_out
in
zip
(
outputs
,
composite_outputs
):
for
old_out
,
composite_out
in
zip
(
outputs
,
composite_outputs
):
# Preserve any names on the original outputs
# Preserve any names on the original outputs
if
old_out
.
name
:
if
old_
name
:
=
old_
out
.
name
:
composite_out
.
name
=
old_
out
.
name
composite_out
.
name
=
old_name
starting_nodes
=
len
(
fgraph
.
apply_nodes
)
starting_nodes
=
len
(
fgraph
.
apply_nodes
)
fgraph
.
replace_all_validate
(
fgraph
.
replace_all_validate
(
list
(
zip
(
outputs
,
composite_outputs
,
strict
=
True
)),
tuple
(
zip
(
outputs
,
composite_outputs
)),
reason
=
reason
reason
=
self
.
__class__
.
__name__
,
)
)
nb_fused
+=
1
nb_fused
+=
1
nb_replacement
+=
(
starting_nodes
-
len
(
fgraph
.
apply_nodes
))
+
1
nb_replacement
+=
(
starting_nodes
-
len
(
fgraph
.
apply_nodes
))
+
1
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论