Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
14d2454c
Unverified
提交
14d2454c
authored
7月 26, 2023
作者:
Maxim Kochurov
提交者:
GitHub
7月 26, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify signatures of `graph_replace` and `clone_replace` (#398)
* more type hints
上级
e9a7d7ce
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
205 行增加
和
44 行删除
+205
-44
pfunc.py
pytensor/compile/function/pfunc.py
+105
-4
replace.py
pytensor/graph/replace.py
+89
-40
test_replace.py
tests/graph/test_replace.py
+11
-0
没有找到文件。
pytensor/compile/function/pfunc.py
浏览文件 @
14d2454c
...
...
@@ -4,7 +4,7 @@ Provide a simple user friendly API.
"""
from
copy
import
copy
from
typing
import
Optional
from
typing
import
Optional
,
Sequence
,
Union
,
overload
from
pytensor.compile.function.types
import
Function
,
UnusedInputError
,
orig_function
from
pytensor.compile.io
import
In
,
Out
...
...
@@ -15,8 +15,9 @@ from pytensor.graph.basic import Constant, Variable, clone_node_and_cache
from
pytensor.graph.fg
import
FunctionGraph
@overload
def
rebuild_collect_shared
(
outputs
,
outputs
:
Variable
,
inputs
=
None
,
replace
=
None
,
updates
=
None
,
...
...
@@ -24,7 +25,107 @@ def rebuild_collect_shared(
copy_inputs_over
=
True
,
no_default_updates
=
False
,
clone_inner_graphs
=
False
,
):
)
->
tuple
[
list
[
Variable
],
Variable
,
tuple
[
dict
[
Variable
,
Variable
],
dict
[
SharedVariable
,
Variable
],
list
[
Variable
],
list
[
SharedVariable
],
],
]:
...
@overload
def
rebuild_collect_shared
(
outputs
:
Sequence
[
Variable
],
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
,
clone_inner_graphs
=
False
,
)
->
tuple
[
list
[
Variable
],
list
[
Variable
],
tuple
[
dict
[
Variable
,
Variable
],
dict
[
SharedVariable
,
Variable
],
list
[
Variable
],
list
[
SharedVariable
],
],
]:
...
@overload
def
rebuild_collect_shared
(
outputs
:
Out
,
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
,
clone_inner_graphs
=
False
,
)
->
tuple
[
list
[
Variable
],
Out
,
tuple
[
dict
[
Variable
,
Variable
],
dict
[
SharedVariable
,
Variable
],
list
[
Variable
],
list
[
SharedVariable
],
],
]:
...
@overload
def
rebuild_collect_shared
(
outputs
:
Sequence
[
Out
],
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
,
clone_inner_graphs
=
False
,
)
->
tuple
[
list
[
Variable
],
list
[
Out
],
tuple
[
dict
[
Variable
,
Variable
],
dict
[
SharedVariable
,
Variable
],
list
[
Variable
],
list
[
SharedVariable
],
],
]:
...
def
rebuild_collect_shared
(
outputs
:
Union
[
Sequence
[
Variable
],
Variable
,
Out
,
Sequence
[
Out
]],
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
,
clone_inner_graphs
=
False
,
)
->
tuple
[
list
[
Variable
],
Union
[
list
[
Variable
],
Variable
,
Out
,
list
[
Out
]],
tuple
[
dict
[
Variable
,
Variable
],
dict
[
SharedVariable
,
Variable
],
list
[
Variable
],
list
[
SharedVariable
],
],
]:
r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?)
...
...
@@ -260,7 +361,7 @@ def rebuild_collect_shared(
return
(
input_variables
,
cloned_outputs
,
[
clone_d
,
update_d
,
update_expr
,
shared_inputs
]
,
(
clone_d
,
update_d
,
update_expr
,
shared_inputs
)
,
)
...
...
pytensor/graph/replace.py
浏览文件 @
14d2454c
from
functools
import
partial
from
typing
import
(
Collection
,
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
)
from
pytensor.graph.basic
import
Constant
,
Variable
,
truncated_graph_inputs
from
typing
import
Iterable
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
truncated_graph_inputs
from
pytensor.graph.fg
import
FunctionGraph
ReplaceTypes
=
Union
[
Iterable
[
tuple
[
Variable
,
Variable
]],
dict
[
Variable
,
Variable
]]
def
_format_replace
(
replace
:
Optional
[
ReplaceTypes
]
=
None
)
->
dict
[
Variable
,
Variable
]:
items
:
dict
[
Variable
,
Variable
]
if
isinstance
(
replace
,
dict
):
# PyLance has issues with type resolution
items
=
cast
(
dict
[
Variable
,
Variable
],
replace
)
elif
isinstance
(
replace
,
Iterable
):
items
=
dict
(
replace
)
elif
replace
is
None
:
items
=
{}
else
:
raise
ValueError
(
"replace is neither a dictionary, list, "
f
"tuple or None ! The value provided is {replace},"
f
"of type {type(replace)}"
)
return
items
@overload
def
clone_replace
(
output
:
Sequence
[
Variable
],
replace
:
Optional
[
ReplaceTypes
]
=
None
,
**
rebuild_kwds
,
)
->
list
[
Variable
]:
...
@overload
def
clone_replace
(
output
:
Collection
[
Variable
]
,
output
:
Variable
,
replace
:
Optional
[
Union
[
Iterable
[
Tuple
[
Variable
,
Variable
]],
D
ict
[
Variable
,
Variable
]]
Union
[
Iterable
[
tuple
[
Variable
,
Variable
]],
d
ict
[
Variable
,
Variable
]]
]
=
None
,
**
rebuild_kwds
,
)
->
List
[
Variable
]:
)
->
Variable
:
...
def
clone_replace
(
output
:
Union
[
Sequence
[
Variable
],
Variable
],
replace
:
Optional
[
ReplaceTypes
]
=
None
,
**
rebuild_kwds
,
)
->
Union
[
list
[
Variable
],
Variable
]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
...
...
@@ -39,19 +68,8 @@ def clone_replace(
"""
from
pytensor.compile.function.pfunc
import
rebuild_collect_shared
items
:
Union
[
List
[
Tuple
[
Variable
,
Variable
]],
Tuple
[
Tuple
[
Variable
,
Variable
],
...
]]
if
isinstance
(
replace
,
dict
):
items
=
list
(
replace
.
items
())
elif
isinstance
(
replace
,
(
list
,
tuple
)):
items
=
replace
elif
replace
is
None
:
items
=
[]
else
:
raise
ValueError
(
"replace is neither a dictionary, list, "
f
"tuple or None ! The value provided is {replace},"
f
"of type {type(replace)}"
)
items
=
list
(
_format_replace
(
replace
)
.
items
())
tmp_replace
=
[(
x
,
x
.
type
())
for
x
,
y
in
items
]
new_replace
=
[(
x
,
y
)
for
((
_
,
x
),
(
_
,
y
))
in
zip
(
tmp_replace
,
items
)]
_
,
_outs
,
_
=
rebuild_collect_shared
(
output
,
[],
tmp_replace
,
[],
**
rebuild_kwds
)
...
...
@@ -59,20 +77,40 @@ def clone_replace(
# TODO Explain why we call it twice ?!
_
,
outs
,
_
=
rebuild_collect_shared
(
_outs
,
[],
new_replace
,
[],
**
rebuild_kwds
)
return
cast
(
List
[
Variable
],
outs
)
return
outs
@overload
def
graph_replace
(
outputs
:
Variable
,
replace
:
Optional
[
ReplaceTypes
]
=
None
,
*
,
strict
=
True
,
)
->
Variable
:
...
@overload
def
graph_replace
(
outputs
:
Sequence
[
Variable
],
replace
:
Dict
[
Variable
,
Variable
],
replace
:
Optional
[
ReplaceTypes
]
=
None
,
*
,
strict
=
True
,
)
->
list
[
Variable
]:
...
def
graph_replace
(
outputs
:
Union
[
Sequence
[
Variable
],
Variable
],
replace
:
Optional
[
ReplaceTypes
]
=
None
,
*
,
strict
=
True
,
)
->
List
[
Variable
]:
)
->
Union
[
list
[
Variable
],
Variable
]:
"""Replace variables in ``outputs`` by ``replace``.
Parameters
----------
outputs:
Sequence[
Variable]
outputs:
Union[Sequence[Variable],
Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
...
...
@@ -83,20 +121,26 @@ def graph_replace(
Returns
-------
List[Variable
]
Output graph with subgraphs replaced
Union[Variable, List[Variable]
]
Output graph with subgraphs replaced
, see function overload for the exact type
Raises
------
ValueError
If some replacemens could not be applied and strict is True
If some replacemen
t
s could not be applied and strict is True
"""
as_list
=
False
if
not
isinstance
(
outputs
,
Sequence
):
outputs
=
[
outputs
]
else
:
as_list
=
True
replace_dict
=
_format_replace
(
replace
)
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions
=
[
c
for
c
in
truncated_graph_inputs
(
outputs
,
replace
)
for
c
in
truncated_graph_inputs
(
outputs
,
replace
_dict
)
if
not
isinstance
(
c
,
Constant
)
]
# for the function graph we need the clean graph where
...
...
@@ -117,7 +161,7 @@ def graph_replace(
# replace the conditions back
fg_replace
=
{
equiv
[
c
]:
c
for
c
in
conditions
}
# add the replacements on top of input mappings
fg_replace
.
update
({
equiv
[
r
]:
v
for
r
,
v
in
replace
.
items
()
if
r
in
equiv
})
fg_replace
.
update
({
equiv
[
r
]:
v
for
r
,
v
in
replace
_dict
.
items
()
if
r
in
equiv
})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
...
...
@@ -126,12 +170,14 @@ def graph_replace(
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if
strict
:
non_fg_replace
=
{
r
:
v
for
r
,
v
in
replace
.
items
()
if
r
not
in
equiv
}
non_fg_replace
=
{
r
:
v
for
r
,
v
in
replace
_dict
.
items
()
if
r
not
in
equiv
}
if
non_fg_replace
:
raise
ValueError
(
f
"Some replacements were not used: {non_fg_replace}"
)
toposort
=
fg
.
toposort
()
def
toposort_key
(
fg
:
FunctionGraph
,
ts
,
pair
):
def
toposort_key
(
fg
:
FunctionGraph
,
ts
:
list
[
Apply
],
pair
:
tuple
[
Variable
,
Variable
]
)
->
int
:
key
,
_
=
pair
if
key
.
owner
is
not
None
:
return
ts
.
index
(
key
.
owner
)
...
...
@@ -148,4 +194,7 @@ def graph_replace(
reverse
=
True
,
)
fg
.
replace_all
(
sorted_replacements
,
import_missing
=
True
)
if
as_list
:
return
list
(
fg
.
outputs
)
else
:
return
fg
.
outputs
[
0
]
tests/graph/test_replace.py
浏览文件 @
14d2454c
...
...
@@ -169,6 +169,17 @@ class TestGraphReplace:
# the old reference is still kept
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
is
w
def
test_non_list_input
(
self
):
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
o
=
MyOp
(
"xyop"
)(
x
,
y
)
new_x
=
x
.
clone
(
name
=
"x_new"
)
new_y
=
y
.
clone
(
name
=
"y2_new"
)
# test non list inputs as well
oc
=
graph_replace
(
o
,
{
x
:
new_x
,
y
:
new_y
})
assert
oc
.
owner
.
inputs
[
1
]
is
new_y
assert
oc
.
owner
.
inputs
[
0
]
is
new_x
def
test_graph_replace_advanced
(
self
):
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论