Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aa7e4d6b
提交
aa7e4d6b
authored
9月 01, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
9月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speedup FunctionGraph methods
上级
066307f0
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
43 行增加
和
68 行删除
+43
-68
fg.py
pytensor/graph/fg.py
+43
-68
没有找到文件。
pytensor/graph/fg.py
浏览文件 @
aa7e4d6b
...
@@ -24,7 +24,6 @@ from pytensor.graph.traversal import (
...
@@ -24,7 +24,6 @@ from pytensor.graph.traversal import (
vars_between
,
vars_between
,
)
)
from
pytensor.graph.utils
import
MetaObject
,
MissingInputError
,
TestValueError
from
pytensor.graph.utils
import
MetaObject
,
MissingInputError
,
TestValueError
from
pytensor.misc.ordered_set
import
OrderedSet
ClientType
=
tuple
[
Apply
,
int
]
ClientType
=
tuple
[
Apply
,
int
]
...
@@ -133,7 +132,6 @@ class FunctionGraph(MetaObject):
...
@@ -133,7 +132,6 @@ class FunctionGraph(MetaObject):
features
=
[]
features
=
[]
self
.
_features
:
list
[
Feature
]
=
[]
self
.
_features
:
list
[
Feature
]
=
[]
# All apply nodes in the subgraph defined by inputs and
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
# outputs are cached in this field
self
.
apply_nodes
:
set
[
Apply
]
=
set
()
self
.
apply_nodes
:
set
[
Apply
]
=
set
()
...
@@ -161,7 +159,8 @@ class FunctionGraph(MetaObject):
...
@@ -161,7 +159,8 @@ class FunctionGraph(MetaObject):
"input's owner or use graph.clone."
"input's owner or use graph.clone."
)
)
self
.
add_input
(
in_var
,
check
=
False
)
self
.
inputs
.
append
(
in_var
)
self
.
clients
.
setdefault
(
in_var
,
[])
for
output
in
outputs
:
for
output
in
outputs
:
self
.
add_output
(
output
,
reason
=
"init"
)
self
.
add_output
(
output
,
reason
=
"init"
)
...
@@ -189,16 +188,6 @@ class FunctionGraph(MetaObject):
...
@@ -189,16 +188,6 @@ class FunctionGraph(MetaObject):
return
return
self
.
inputs
.
append
(
var
)
self
.
inputs
.
append
(
var
)
self
.
setup_var
(
var
)
def
setup_var
(
self
,
var
:
Variable
)
->
None
:
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : pytensor.graph.basic.Variable
"""
self
.
clients
.
setdefault
(
var
,
[])
self
.
clients
.
setdefault
(
var
,
[])
def
get_clients
(
self
,
var
:
Variable
)
->
list
[
ClientType
]:
def
get_clients
(
self
,
var
:
Variable
)
->
list
[
ClientType
]:
...
@@ -322,10 +311,11 @@ class FunctionGraph(MetaObject):
...
@@ -322,10 +311,11 @@ class FunctionGraph(MetaObject):
"""
"""
# Imports the owners of the variables
# Imports the owners of the variables
if
var
.
owner
and
var
.
owner
not
in
self
.
apply_nodes
:
apply
=
var
.
owner
self
.
import_node
(
var
.
owner
,
reason
=
reason
,
import_missing
=
import_missing
)
if
apply
is
not
None
and
apply
not
in
self
.
apply_nodes
:
self
.
import_node
(
apply
,
reason
=
reason
,
import_missing
=
import_missing
)
elif
(
elif
(
var
.
owner
is
None
apply
is
None
and
not
isinstance
(
var
,
AtomicVariable
)
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
and
var
not
in
self
.
inputs
):
):
...
@@ -336,10 +326,11 @@ class FunctionGraph(MetaObject):
...
@@ -336,10 +326,11 @@ class FunctionGraph(MetaObject):
f
"Computation graph contains a NaN. {var.type.why_null}"
f
"Computation graph contains a NaN. {var.type.why_null}"
)
)
if
import_missing
:
if
import_missing
:
self
.
add_input
(
var
)
self
.
inputs
.
append
(
var
)
self
.
clients
.
setdefault
(
var
,
[])
else
:
else
:
raise
MissingInputError
(
f
"Undeclared input: {var}"
,
variable
=
var
)
raise
MissingInputError
(
f
"Undeclared input: {var}"
,
variable
=
var
)
self
.
setup_var
(
var
)
self
.
clients
.
setdefault
(
var
,
[]
)
self
.
variables
.
add
(
var
)
self
.
variables
.
add
(
var
)
def
import_node
(
def
import_node
(
...
@@ -356,29 +347,29 @@ class FunctionGraph(MetaObject):
...
@@ -356,29 +347,29 @@ class FunctionGraph(MetaObject):
apply_node : Apply
apply_node : Apply
The node to be imported.
The node to be imported.
check : bool
check : bool
Check that the inputs for the imported nodes are also present in
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
the `FunctionGraph`.
reason : str
reason : str
The name of the optimization or operation in progress.
The name of the optimization or operation in progress.
import_missing : bool
import_missing : bool
Add missing inputs instead of raising an exception.
Add missing inputs instead of raising an exception.
"""
"""
# We import the nodes in topological order. We only are interested in
# We import the nodes in topological order. We only are interested in
# new nodes, so we use all variables we know of as if they were the
# new nodes, so we use all nodes we know of as inputs to interrupt the toposort
# input set. (The functions in the graph module only use the input set
self_variables
=
self
.
variables
# to know where to stop going down.)
self_clients
=
self
.
clients
new_nodes
=
tuple
(
toposort
(
apply_node
.
outputs
,
blockers
=
self
.
variables
))
self_apply_nodes
=
self
.
apply_nodes
self_inputs
=
self
.
inputs
for
node
in
toposort
(
apply_node
.
outputs
,
blockers
=
self_variables
):
if
check
:
if
check
:
for
node
in
new_nodes
:
for
var
in
node
.
inputs
:
for
var
in
node
.
inputs
:
if
(
if
(
var
.
owner
is
None
var
.
owner
is
None
and
not
isinstance
(
var
,
AtomicVariable
)
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
and
var
not
in
self
_
inputs
):
):
if
import_missing
:
if
import_missing
:
self
.
add_input
(
var
)
self_inputs
.
append
(
var
)
self_clients
.
setdefault
(
var
,
[])
else
:
else
:
error_msg
=
(
error_msg
=
(
f
"Input {node.inputs.index(var)} ({var})"
f
"Input {node.inputs.index(var)} ({var})"
...
@@ -390,20 +381,20 @@ class FunctionGraph(MetaObject):
...
@@ -390,20 +381,20 @@ class FunctionGraph(MetaObject):
)
)
raise
MissingInputError
(
error_msg
,
variable
=
var
)
raise
MissingInputError
(
error_msg
,
variable
=
var
)
for
node
in
new_nodes
:
self_apply_nodes
.
add
(
node
)
assert
node
not
in
self
.
apply_nodes
tag
=
node
.
tag
self
.
apply_nodes
.
add
(
node
)
if
not
hasattr
(
tag
,
"imported_by"
):
if
not
hasattr
(
node
.
tag
,
"imported_by"
):
tag
.
imported_by
=
[
str
(
reason
)]
node
.
tag
.
imported_by
=
[]
else
:
node
.
tag
.
imported_by
.
append
(
str
(
reason
))
tag
.
imported_by
.
append
(
str
(
reason
))
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
self
.
setup_var
(
output
)
self
_clients
.
setdefault
(
output
,
[]
)
self
.
variables
.
add
(
output
)
self
_
variables
.
add
(
output
)
for
i
,
inp
ut
in
enumerate
(
node
.
inputs
):
for
i
,
inp
in
enumerate
(
node
.
inputs
):
if
inp
ut
not
in
self
.
variables
:
if
inp
not
in
self_
variables
:
self
.
setup_var
(
input
)
self
_clients
.
setdefault
(
inp
,
[]
)
self
.
variables
.
add
(
input
)
self
_variables
.
add
(
inp
)
self
.
add_client
(
input
,
(
node
,
i
))
self
_clients
[
inp
]
.
append
(
(
node
,
i
))
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
def
change_node_input
(
def
change_node_input
(
...
@@ -457,7 +448,7 @@ class FunctionGraph(MetaObject):
...
@@ -457,7 +448,7 @@ class FunctionGraph(MetaObject):
self
.
outputs
[
node
.
op
.
idx
]
=
new_var
self
.
outputs
[
node
.
op
.
idx
]
=
new_var
self
.
import_var
(
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
import_var
(
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
add_client
(
new_var
,
(
node
,
i
))
self
.
clients
[
new_var
]
.
append
(
(
node
,
i
))
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
# Precondition: the substitution is semantically valid However it may
# Precondition: the substitution is semantically valid However it may
# introduce cycles to the graph, in which case the transaction will be
# introduce cycles to the graph, in which case the transaction will be
...
@@ -756,10 +747,6 @@ class FunctionGraph(MetaObject):
...
@@ -756,10 +747,6 @@ class FunctionGraph(MetaObject):
:meth:`FunctionGraph.orderings`.
:meth:`FunctionGraph.orderings`.
"""
"""
if
len
(
self
.
apply_nodes
)
<
2
:
# No sorting is necessary
return
list
(
self
.
apply_nodes
)
return
list
(
toposort_with_orderings
(
self
.
outputs
,
orderings
=
self
.
orderings
()))
return
list
(
toposort_with_orderings
(
self
.
outputs
,
orderings
=
self
.
orderings
()))
def
orderings
(
self
)
->
dict
[
Apply
,
list
[
Apply
]]:
def
orderings
(
self
)
->
dict
[
Apply
,
list
[
Apply
]]:
...
@@ -779,29 +766,17 @@ class FunctionGraph(MetaObject):
...
@@ -779,29 +766,17 @@ class FunctionGraph(MetaObject):
take care of computing the dependencies by itself.
take care of computing the dependencies by itself.
"""
"""
assert
isinstance
(
self
.
_features
,
list
)
all_orderings
:
list
[
dict
]
=
[
all_orderings
:
list
[
dict
]
=
[]
orderings
for
feature
in
self
.
_features
for
feature
in
self
.
_features
:
if
(
if
hasattr
(
feature
,
"orderings"
):
hasattr
(
feature
,
"orderings"
)
and
(
orderings
:
=
feature
.
orderings
(
self
))
orderings
=
feature
.
orderings
(
self
)
if
not
isinstance
(
orderings
,
dict
):
raise
TypeError
(
"Non-deterministic return value from "
+
str
(
feature
.
orderings
)
+
". Nondeterministic object is "
+
str
(
orderings
)
)
if
len
(
orderings
)
>
0
:
all_orderings
.
append
(
orderings
)
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
list
|
OrderedSet
):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
)
if
len
(
all_orderings
)
==
1
:
]
if
not
all_orderings
:
return
{}
elif
len
(
all_orderings
)
==
1
:
# If there is only 1 ordering, we reuse it directly.
# If there is only 1 ordering, we reuse it directly.
return
all_orderings
[
0
]
.
copy
()
return
all_orderings
[
0
]
.
copy
()
else
:
else
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论