Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aa7e4d6b
提交
aa7e4d6b
authored
9月 01, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
9月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speedup FunctionGraph methods
上级
066307f0
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
44 行增加
和
69 行删除
+44
-69
fg.py
pytensor/graph/fg.py
+44
-69
没有找到文件。
pytensor/graph/fg.py
浏览文件 @
aa7e4d6b
...
...
@@ -24,7 +24,6 @@ from pytensor.graph.traversal import (
vars_between
,
)
from
pytensor.graph.utils
import
MetaObject
,
MissingInputError
,
TestValueError
from
pytensor.misc.ordered_set
import
OrderedSet
ClientType
=
tuple
[
Apply
,
int
]
...
...
@@ -133,7 +132,6 @@ class FunctionGraph(MetaObject):
features
=
[]
self
.
_features
:
list
[
Feature
]
=
[]
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self
.
apply_nodes
:
set
[
Apply
]
=
set
()
...
...
@@ -161,7 +159,8 @@ class FunctionGraph(MetaObject):
"input's owner or use graph.clone."
)
self
.
add_input
(
in_var
,
check
=
False
)
self
.
inputs
.
append
(
in_var
)
self
.
clients
.
setdefault
(
in_var
,
[])
for
output
in
outputs
:
self
.
add_output
(
output
,
reason
=
"init"
)
...
...
@@ -189,16 +188,6 @@ class FunctionGraph(MetaObject):
return
self
.
inputs
.
append
(
var
)
self
.
setup_var
(
var
)
def
setup_var
(
self
,
var
:
Variable
)
->
None
:
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : pytensor.graph.basic.Variable
"""
self
.
clients
.
setdefault
(
var
,
[])
def
get_clients
(
self
,
var
:
Variable
)
->
list
[
ClientType
]:
...
...
@@ -322,10 +311,11 @@ class FunctionGraph(MetaObject):
"""
# Imports the owners of the variables
if
var
.
owner
and
var
.
owner
not
in
self
.
apply_nodes
:
self
.
import_node
(
var
.
owner
,
reason
=
reason
,
import_missing
=
import_missing
)
apply
=
var
.
owner
if
apply
is
not
None
and
apply
not
in
self
.
apply_nodes
:
self
.
import_node
(
apply
,
reason
=
reason
,
import_missing
=
import_missing
)
elif
(
var
.
owner
is
None
apply
is
None
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
):
...
...
@@ -336,10 +326,11 @@ class FunctionGraph(MetaObject):
f
"Computation graph contains a NaN. {var.type.why_null}"
)
if
import_missing
:
self
.
add_input
(
var
)
self
.
inputs
.
append
(
var
)
self
.
clients
.
setdefault
(
var
,
[])
else
:
raise
MissingInputError
(
f
"Undeclared input: {var}"
,
variable
=
var
)
self
.
setup_var
(
var
)
self
.
clients
.
setdefault
(
var
,
[]
)
self
.
variables
.
add
(
var
)
def
import_node
(
...
...
@@ -356,29 +347,29 @@ class FunctionGraph(MetaObject):
apply_node : Apply
The node to be imported.
check : bool
Check that the inputs for the imported nodes are also present in
the `FunctionGraph`.
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
"""
# We import the nodes in topological order. We only are interested in
# new nodes, so we use all
variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes
=
tuple
(
toposort
(
apply_node
.
outputs
,
blockers
=
self
.
variables
))
if
check
:
for
node
in
new_nodes
:
# new nodes, so we use all
nodes we know of as inputs to interrupt the toposort
self_variables
=
self
.
variables
self_clients
=
self
.
clients
self_apply_nodes
=
self
.
apply_nodes
self_inputs
=
self
.
inputs
for
node
in
toposort
(
apply_node
.
outputs
,
blockers
=
self_variables
)
:
if
check
:
for
var
in
node
.
inputs
:
if
(
var
.
owner
is
None
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
and
var
not
in
self
_
inputs
):
if
import_missing
:
self
.
add_input
(
var
)
self_inputs
.
append
(
var
)
self_clients
.
setdefault
(
var
,
[])
else
:
error_msg
=
(
f
"Input {node.inputs.index(var)} ({var})"
...
...
@@ -390,20 +381,20 @@ class FunctionGraph(MetaObject):
)
raise
MissingInputError
(
error_msg
,
variable
=
var
)
for
node
in
new_nodes
:
assert
node
not
in
self
.
apply_nodes
self
.
apply_nodes
.
add
(
node
)
if
not
hasattr
(
node
.
tag
,
"imported_by"
):
node
.
tag
.
imported_by
=
[]
node
.
tag
.
imported_by
.
append
(
str
(
reason
))
self_apply_nodes
.
add
(
node
)
tag
=
node
.
tag
if
not
hasattr
(
tag
,
"imported_by"
):
tag
.
imported_by
=
[
str
(
reason
)]
else
:
tag
.
imported_by
.
append
(
str
(
reason
))
for
output
in
node
.
outputs
:
self
.
setup_var
(
output
)
self
.
variables
.
add
(
output
)
for
i
,
inp
ut
in
enumerate
(
node
.
inputs
):
if
inp
ut
not
in
self
.
variables
:
self
.
setup_var
(
input
)
self
.
variables
.
add
(
input
)
self
.
add_client
(
input
,
(
node
,
i
))
self
_clients
.
setdefault
(
output
,
[]
)
self
_
variables
.
add
(
output
)
for
i
,
inp
in
enumerate
(
node
.
inputs
):
if
inp
not
in
self_
variables
:
self
_clients
.
setdefault
(
inp
,
[]
)
self
_variables
.
add
(
inp
)
self
_clients
[
inp
]
.
append
(
(
node
,
i
))
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
def
change_node_input
(
...
...
@@ -457,7 +448,7 @@ class FunctionGraph(MetaObject):
self
.
outputs
[
node
.
op
.
idx
]
=
new_var
self
.
import_var
(
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
add_client
(
new_var
,
(
node
,
i
))
self
.
clients
[
new_var
]
.
append
(
(
node
,
i
))
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
# Precondition: the substitution is semantically valid However it may
# introduce cycles to the graph, in which case the transaction will be
...
...
@@ -756,10 +747,6 @@ class FunctionGraph(MetaObject):
:meth:`FunctionGraph.orderings`.
"""
if
len
(
self
.
apply_nodes
)
<
2
:
# No sorting is necessary
return
list
(
self
.
apply_nodes
)
return
list
(
toposort_with_orderings
(
self
.
outputs
,
orderings
=
self
.
orderings
()))
def
orderings
(
self
)
->
dict
[
Apply
,
list
[
Apply
]]:
...
...
@@ -779,29 +766,17 @@ class FunctionGraph(MetaObject):
take care of computing the dependencies by itself.
"""
assert
isinstance
(
self
.
_features
,
list
)
all_orderings
:
list
[
dict
]
=
[]
all_orderings
:
list
[
dict
]
=
[
orderings
for
feature
in
self
.
_features
if
(
hasattr
(
feature
,
"orderings"
)
and
(
orderings
:
=
feature
.
orderings
(
self
))
)
]
for
feature
in
self
.
_features
:
if
hasattr
(
feature
,
"orderings"
):
orderings
=
feature
.
orderings
(
self
)
if
not
isinstance
(
orderings
,
dict
):
raise
TypeError
(
"Non-deterministic return value from "
+
str
(
feature
.
orderings
)
+
". Nondeterministic object is "
+
str
(
orderings
)
)
if
len
(
orderings
)
>
0
:
all_orderings
.
append
(
orderings
)
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
list
|
OrderedSet
):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
if
len
(
all_orderings
)
==
1
:
if
not
all_orderings
:
return
{}
elif
len
(
all_orderings
)
==
1
:
# If there is only 1 ordering, we reuse it directly.
return
all_orderings
[
0
]
.
copy
()
else
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论