Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
befc177d
提交
befc177d
authored
12月 09, 2022
作者:
Maxim Kochurov
提交者:
Ricardo Vieira
12月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add graph_replace function
上级
1b673567
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
174 行增加
和
4 行删除
+174
-4
__init__.py
pytensor/__init__.py
+1
-1
__init__.py
pytensor/graph/__init__.py
+1
-1
replace.py
pytensor/graph/replace.py
+92
-1
test_replace.py
tests/graph/test_replace.py
+80
-1
没有找到文件。
pytensor/__init__.py
浏览文件 @
befc177d
...
@@ -74,7 +74,7 @@ __api_version__ = 1
...
@@ -74,7 +74,7 @@ __api_version__ = 1
# isort: off
# isort: off
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
# isort: on
# isort: on
...
...
pytensor/graph/__init__.py
浏览文件 @
befc177d
...
@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
...
@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone
,
clone
,
ancestors
,
ancestors
,
)
)
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.type
import
Type
from
pytensor.graph.type
import
Type
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
...
...
pytensor/graph/replace.py
浏览文件 @
befc177d
from
functools
import
partial
from
typing
import
(
from
typing
import
(
Collection
,
Collection
,
Dict
,
Dict
,
...
@@ -10,7 +11,8 @@ from typing import (
...
@@ -10,7 +11,8 @@ from typing import (
cast
,
cast
,
)
)
from
pytensor.graph.basic
import
Constant
,
Variable
from
pytensor.graph.basic
import
Constant
,
Variable
,
truncated_graph_inputs
from
pytensor.graph.fg
import
FunctionGraph
def
clone_replace
(
def
clone_replace
(
...
@@ -58,3 +60,92 @@ def clone_replace(
...
@@ -58,3 +60,92 @@ def clone_replace(
_
,
outs
,
_
=
rebuild_collect_shared
(
_outs
,
[],
new_replace
,
[],
**
rebuild_kwds
)
_
,
outs
,
_
=
rebuild_collect_shared
(
_outs
,
[],
new_replace
,
[],
**
rebuild_kwds
)
return
cast
(
List
[
Variable
],
outs
)
return
cast
(
List
[
Variable
],
outs
)
def
graph_replace
(
outputs
:
Sequence
[
Variable
],
replace
:
Dict
[
Variable
,
Variable
],
*
,
strict
=
True
,
)
->
List
[
Variable
]:
"""Replace variables in ``outputs`` by ``replace``.
Parameters
----------
outputs: Sequence[Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
strict: bool
Raise an error if some replacements were not used
return_unused: bool
Return replacements that were not used
Returns
-------
List[Variable]
Output graph with subgraphs replaced
Raises
------
ValueError
If some replacemens could not be applied and strict is True
"""
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions
=
[
c
for
c
in
truncated_graph_inputs
(
outputs
,
replace
)
if
not
isinstance
(
c
,
Constant
)
]
# for the function graph we need the clean graph where
# inputs do not have owners
# this is exactly the reason to clone conditions
equiv
=
{
c
:
c
.
clone
(
name
=
f
"i-{i}"
)
for
i
,
c
in
enumerate
(
conditions
)}
# some replace keys may dissapear
# the reason is they are outside the graph
# clone the graph but preserve the equiv mapping
fg
=
FunctionGraph
(
conditions
,
outputs
,
# clone_get_equiv kwargs
copy_orphans
=
False
,
copy_inputs
=
False
,
memo
=
equiv
,
)
# replace the conditions back
fg_replace
=
{
equiv
[
c
]:
c
for
c
in
conditions
}
# add the replacements on top of input mappings
fg_replace
.
update
({
equiv
[
r
]:
v
for
r
,
v
in
replace
.
items
()
if
r
in
equiv
})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
# some replacements may be initially outside the graph
# but later introduced by a replacement
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if
strict
:
non_fg_replace
=
{
r
:
v
for
r
,
v
in
replace
.
items
()
if
r
not
in
equiv
}
if
non_fg_replace
:
raise
ValueError
(
f
"Some replacements were not used: {non_fg_replace}"
)
toposort
=
fg
.
toposort
()
def
toposort_key
(
fg
:
FunctionGraph
,
ts
,
pair
):
key
,
_
=
pair
if
key
.
owner
is
not
None
:
return
ts
.
index
(
key
.
owner
)
else
:
if
key
in
fg
.
variables
:
return
-
1
else
:
raise
ValueError
(
f
"{key} is not a part of graph"
)
sorted_replacements
=
sorted
(
tuple
(
fg_replace
.
items
()),
# sort based on the fg toposort, if a variable has no owner, it goes first
key
=
partial
(
toposort_key
,
fg
,
toposort
),
reverse
=
True
,
)
fg
.
replace_all
(
sorted_replacements
,
import_missing
=
True
)
return
list
(
fg
.
outputs
)
tests/graph/test_replace.py
浏览文件 @
befc177d
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
function
,
shared
from
pytensor
import
config
,
function
,
shared
from
pytensor.graph.basic
import
graph_inputs
from
pytensor.graph.basic
import
graph_inputs
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
from
pytensor.tensor
import
dvector
,
fvector
,
vector
from
pytensor.tensor
import
dvector
,
fvector
,
vector
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.graph.utils
import
MyOp
,
MyVariable
from
tests.graph.utils
import
MyOp
,
MyVariable
...
@@ -133,3 +133,82 @@ class TestCloneReplace:
...
@@ -133,3 +133,82 @@ class TestCloneReplace:
utt
.
assert_allclose
(
utt
.
assert_allclose
(
test
(
x
,
pt
.
sum
((
x
+
1
)
**
2
),
mention_y
=
True
),
1.21000003815
test
(
x
,
pt
.
sum
((
x
+
1
)
**
2
),
mention_y
=
True
),
1.21000003815
)
)
class
TestGraphReplace
:
def
test_graph_replace
(
self
):
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
z
=
MyVariable
(
"z"
)
w
=
MyVariable
(
"w"
)
MyOp
(
"zop"
)(
z
)
x2
=
MyOp
(
"xop"
)(
x
,
w
)
x2
.
name
=
"x2"
y2
=
MyOp
(
"yop"
)(
y
)
y2
.
name
=
"y2"
yc
=
graph_replace
([
x2
],
{
x
:
y2
})[
0
]
assert
yc
.
owner
.
inputs
[
0
]
is
y2
# the old reference is kept
assert
yc
.
owner
.
inputs
[
1
]
is
w
# test replace itself
yc
=
graph_replace
([
x2
],
{
x2
:
y2
})[
0
]
assert
yc
is
y2
assert
yc
.
owner
.
inputs
[
0
]
is
y
assert
len
(
yc
.
owner
.
inputs
)
==
1
# the case where inputs have to be replaced in reverse topological order
o
=
MyOp
(
"xyop"
)(
x2
,
y2
)
new_x
=
x
.
clone
(
name
=
"x_new"
)
new_y2
=
y2
.
clone
(
name
=
"y2_new"
)
oc
=
graph_replace
([
o
],
{
x
:
new_x
,
y2
:
new_y2
})[
0
]
assert
oc
.
owner
.
inputs
[
1
]
is
new_y2
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
is
new_x
# the old reference is still kept
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
is
w
def
test_graph_replace_advanced
(
self
):
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
z
=
MyVariable
(
"z"
)
w
=
MyVariable
(
"w"
)
z2
=
MyOp
(
"zop"
)(
z
)
x2
=
MyOp
(
"xop"
)(
x
,
w
)
x2
.
name
=
"x2"
y2
=
MyOp
(
"yop"
)(
y
)
y2
.
name
=
"y2"
o
=
MyOp
(
"xyop"
)(
x2
,
y2
)
new_x
=
x
.
clone
(
name
=
"x_new"
)
new_y2
=
y2
.
clone
(
name
=
"y2_new"
)
new_y21
=
MyOp
(
"ny2op"
)(
new_y2
)
# now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone
# the case where new variable is referenced during the replacements
new_y21
=
MyOp
(
"ny2op"
)(
new_y2
)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
oc
=
graph_replace
([
o
],
{
x
:
new_x
,
y2
:
new_y21
})
oc
=
graph_replace
(
oc
,
{
new_y2
:
z2
})[
0
]
assert
oc
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
is
z2
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
is
new_x
# the old reference is still kept
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
is
w
new_z
=
z
.
clone
(
name
=
"z_new"
)
oc
=
graph_replace
([
oc
],
{
z
:
new_z
})[
0
]
# new reference appear
assert
oc
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
is
not
z2
assert
oc
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
is
new_z
# the old reference is still kept
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
is
new_x
assert
oc
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
is
w
def
test_graph_replace_disconnected
(
self
):
x
=
MyVariable
(
"x"
)
fake
=
MyOp
(
"fake"
)(
x
)
o
=
MyOp
(
"o"
)(
x
)
oc
=
graph_replace
([
o
],
{
fake
:
x
.
clone
()},
strict
=
False
)
assert
oc
[
0
]
is
o
with
pytest
.
raises
(
ValueError
,
match
=
"Some replacements were not used"
):
oc
=
graph_replace
([
o
],
{
fake
:
x
.
clone
()},
strict
=
True
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论