Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d58d482a
提交
d58d482a
authored
2月 16, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add FunctionGraph methods add_output, remove_node, remove_input, remove_output
上级
124ed5df
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
478 行增加
和
32 行删除
+478
-32
fg.py
aesara/graph/fg.py
+167
-20
test_fg.py
tests/graph/test_fg.py
+302
-5
utils.py
tests/graph/utils.py
+9
-7
没有找到文件。
aesara/graph/fg.py
浏览文件 @
d58d482a
...
@@ -126,13 +126,13 @@ class FunctionGraph(MetaObject):
...
@@ -126,13 +126,13 @@ class FunctionGraph(MetaObject):
# outputs are cached in this field
# outputs are cached in this field
self
.
apply_nodes
:
Set
[
Apply
]
=
set
()
self
.
apply_nodes
:
Set
[
Apply
]
=
set
()
#
Ditto for variable nodes.
#
It includes inputs, outputs, and all intermediate variables
#
It must contain all fgraph.inputs and all apply_nodes
#
connecting the inputs and outputs. It also contains irrelevant
# outputs
even if they aren't used in the graph
.
# outputs
the nodes in `self.apply_nodes`
.
self
.
variables
:
Set
[
Variable
]
=
set
()
self
.
variables
:
Set
[
Variable
]
=
set
()
self
.
inputs
:
List
[
Variable
]
=
[]
self
.
inputs
:
List
[
Variable
]
=
[]
self
.
outputs
:
List
[
Variable
]
=
list
(
outputs
)
self
.
outputs
:
List
[
Variable
]
=
[]
self
.
clients
:
Dict
[
Variable
,
List
[
ClientType
]]
=
{}
self
.
clients
:
Dict
[
Variable
,
List
[
ClientType
]]
=
{}
for
f
in
features
:
for
f
in
features
:
...
@@ -152,13 +152,19 @@ class FunctionGraph(MetaObject):
...
@@ -152,13 +152,19 @@ class FunctionGraph(MetaObject):
self
.
add_input
(
in_var
,
check
=
False
)
self
.
add_input
(
in_var
,
check
=
False
)
for
output
in
outputs
:
for
output
in
outputs
:
self
.
import_var
(
output
,
reason
=
"init"
)
self
.
add_output
(
output
,
reason
=
"init"
)
for
i
,
output
in
enumerate
(
outputs
):
self
.
clients
[
output
]
.
append
((
"output"
,
i
))
self
.
profile
=
None
self
.
profile
=
None
self
.
update_mapping
=
update_mapping
self
.
update_mapping
=
update_mapping
def
add_output
(
self
,
var
:
Variable
,
reason
:
Optional
[
str
]
=
None
,
import_missing
:
bool
=
False
):
"""Add a new variable as an output to this `FunctionGraph`."""
self
.
outputs
.
append
(
var
)
self
.
import_var
(
var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
clients
[
var
]
.
append
((
"output"
,
len
(
self
.
outputs
)
-
1
))
def
add_input
(
self
,
var
:
Variable
,
check
:
bool
=
True
)
->
None
:
def
add_input
(
self
,
var
:
Variable
,
check
:
bool
=
True
)
->
None
:
"""Add a new variable as an input to this `FunctionGraph`.
"""Add a new variable as an input to this `FunctionGraph`.
...
@@ -172,7 +178,6 @@ class FunctionGraph(MetaObject):
...
@@ -172,7 +178,6 @@ class FunctionGraph(MetaObject):
self
.
inputs
.
append
(
var
)
self
.
inputs
.
append
(
var
)
self
.
setup_var
(
var
)
self
.
setup_var
(
var
)
self
.
variables
.
add
(
var
)
def
setup_var
(
self
,
var
:
Variable
)
->
None
:
def
setup_var
(
self
,
var
:
Variable
)
->
None
:
"""Set up a variable so it belongs to this `FunctionGraph`.
"""Set up a variable so it belongs to this `FunctionGraph`.
...
@@ -210,6 +215,7 @@ class FunctionGraph(MetaObject):
...
@@ -210,6 +215,7 @@ class FunctionGraph(MetaObject):
var
:
Variable
,
var
:
Variable
,
client_to_remove
:
ClientType
,
client_to_remove
:
ClientType
,
reason
:
Optional
[
str
]
=
None
,
reason
:
Optional
[
str
]
=
None
,
remove_if_empty
:
bool
=
False
,
)
->
None
:
)
->
None
:
"""Recursively remove clients of a variable.
"""Recursively remove clients of a variable.
...
@@ -222,11 +228,14 @@ class FunctionGraph(MetaObject):
...
@@ -222,11 +228,14 @@ class FunctionGraph(MetaObject):
Parameters
Parameters
----------
----------
var
: Variable
var
The clients of `var` that will be removed.
The clients of `var` that will be removed.
client_to_remove
: pair of (Apply, int)
client_to_remove
A ``(node, i)`` pair such that ``node.inputs[i]`` will no longer be
A ``(node, i)`` pair such that ``node.inputs[i]`` will no longer be
`var` in this `FunctionGraph`.
`var` in this `FunctionGraph`.
remove_if_empty
When ``True``, if `var`'s `Apply` node is removed, remove the
entry for `var` in `self.clients`.
"""
"""
...
@@ -250,8 +259,6 @@ class FunctionGraph(MetaObject):
...
@@ -250,8 +259,6 @@ class FunctionGraph(MetaObject):
# Now, `var` has no more clients, so check if we need to remove it
# Now, `var` has no more clients, so check if we need to remove it
# and its `Apply` node
# and its `Apply` node
if
not
var
.
owner
:
if
not
var
.
owner
:
# The `var` is a `Constant` or an input without a client, so we
# remove it
self
.
variables
.
remove
(
var
)
self
.
variables
.
remove
(
var
)
else
:
else
:
apply_node
=
var
.
owner
apply_node
=
var
.
owner
...
@@ -274,12 +281,15 @@ class FunctionGraph(MetaObject):
...
@@ -274,12 +281,15 @@ class FunctionGraph(MetaObject):
for
i
,
in_var
in
enumerate
(
apply_node
.
inputs
):
for
i
,
in_var
in
enumerate
(
apply_node
.
inputs
):
removal_stack
.
append
((
in_var
,
(
apply_node
,
i
)))
removal_stack
.
append
((
in_var
,
(
apply_node
,
i
)))
if
remove_if_empty
:
del
self
.
clients
[
var
]
def
import_var
(
def
import_var
(
self
,
var
:
Variable
,
reason
:
Optional
[
str
]
=
None
,
import_missing
:
bool
=
False
self
,
var
:
Variable
,
reason
:
Optional
[
str
]
=
None
,
import_missing
:
bool
=
False
)
->
None
:
)
->
None
:
"""Import
variables
into this `FunctionGraph`.
"""Import
a `Variable`
into this `FunctionGraph`.
This will
also import the `variable`'s `Apply` node
.
This will
import the `var`'s `Apply` node and inputs
.
Parameters
Parameters
----------
----------
...
@@ -517,6 +527,147 @@ class FunctionGraph(MetaObject):
...
@@ -517,6 +527,147 @@ class FunctionGraph(MetaObject):
for
var
,
new_var
in
pairs
:
for
var
,
new_var
in
pairs
:
self
.
replace
(
var
,
new_var
,
**
kwargs
)
self
.
replace
(
var
,
new_var
,
**
kwargs
)
def
_remove_output
(
self
,
idx
:
int
):
"""Remove the output at index `idx` and update the indices in the clients entries.
`FunctionGraph.clients` contains entries like ``("output", i)`` under
each output variable in `FunctionGraph.outputs`. The ``i`` values
correspond to each output's location within the `FunctionGraph.outputs`
list, so, when an output is removed from the graph, all these entries
need to be updated. This method performs those updates.
TODO: We could track these entries in a new instance attribute and make
them lists, then each could be updated in-place very easily. This
seems fine, because the `FunctionGraph.clients` ``dict`` and list in
which they're contained are already being updated in-place.
"""
old_idx_mappings
=
tuple
((
out
,
i
)
for
i
,
out
in
enumerate
(
self
.
outputs
))
self
.
outputs
.
pop
(
idx
)
new_idx
=
0
for
(
out
,
old_idx
)
in
old_idx_mappings
:
if
old_idx
==
idx
:
continue
out_clients
=
self
.
clients
[
out
]
arrow
:
ClientType
=
(
"output"
,
old_idx
)
arrow_idx
=
out_clients
.
index
(
arrow
)
out_clients
[
arrow_idx
]
=
(
"output"
,
new_idx
)
new_idx
+=
1
def
remove_node
(
self
,
node
:
Apply
,
reason
:
Optional
[
str
]
=
None
):
"""Remove an `Apply` node from the `FunctionGraph`.
This will remove everything that depends on the outputs of `node`, as
well as any "orphaned" variables and nodes created by `node`'s removal.
"""
if
node
not
in
self
.
apply_nodes
:
return
self
.
apply_nodes
.
remove
(
node
)
if
not
hasattr
(
node
.
tag
,
"removed_by"
):
node
.
tag
.
removed_by
=
[]
node
.
tag
.
removed_by
.
append
(
str
(
reason
))
# Remove the outputs of the node (i.e. everything "below" it)
for
out
in
node
.
outputs
:
self
.
variables
.
remove
(
out
)
out_clients
=
self
.
clients
.
get
(
out
,
())
while
out_clients
:
out_client
,
out_idx
=
out_clients
.
pop
()
if
out_client
==
"output"
:
self
.
_remove_output
(
out_idx
)
# TODO: We could short-circuit all of the graph walking and
# clear everything at once when all the outputs are gone.
# if not self.outputs:
# self.clients = {inp: [] for inp in self.inputs}
# self.variables = set()
# while self.apply_nodes:
# node = self.apply_nodes.pop()
# if not hasattr(node.tag, "removed_by"):
# node.tag.removed_by = []
#
# node.tag.removed_by.append(str(reason))
#
# self.execute_callbacks("on_prune", node, reason)
else
:
assert
isinstance
(
out_client
,
Apply
)
self
.
remove_node
(
out_client
,
reason
=
reason
)
if
out
in
self
.
clients
:
del
self
.
clients
[
out
]
# Remove all the arrows pointing to this `node`, and any orphaned
# variables created by removing those arrows
for
inp_idx
,
inp
in
enumerate
(
node
.
inputs
):
inp_clients
:
List
[
ClientType
]
=
self
.
clients
.
get
(
inp
,
[])
arrow
=
(
node
,
inp_idx
)
if
arrow
not
in
inp_clients
:
continue
inp_clients
.
remove
(
arrow
)
if
not
inp_clients
and
inp
not
in
self
.
outputs
:
if
inp
.
owner
:
# If this input has no clients (after removing this arrow),
# is not an input (i.e. it has a non-`None` owner) or an
# output to the `FunctionGraph`, then it's an orphan
# We need to check whether or not this orphaned input's
# node is still needed in the graph
inp_node
=
inp
.
owner
if
not
any
(
out
in
self
.
variables
for
out
in
inp_node
.
outputs
if
out
is
not
inp
):
self
.
remove_node
(
inp_node
,
reason
=
reason
)
else
:
# This is an unused input
self
.
variables
.
remove
(
inp
)
# The callbacks be triggered after everything has been removed so that
# the `FunctionGraph` state subscribers see is valid.
self
.
execute_callbacks
(
"on_prune"
,
node
,
reason
)
def
remove_input
(
self
,
input_idx
:
int
,
reason
:
Optional
[
str
]
=
None
):
"""Remove the input at index `input_idx`."""
var
=
self
.
inputs
.
pop
(
input_idx
)
for
client
,
idx
in
list
(
self
.
clients
[
var
]):
if
client
==
"output"
:
out_var
=
self
.
outputs
[
idx
]
out_node
=
out_var
.
owner
if
out_node
is
None
:
assert
out_var
in
self
.
inputs
self
.
outputs
.
pop
(
idx
)
continue
client_node
=
out_node
else
:
assert
isinstance
(
client
,
Apply
)
client_node
=
client
self
.
remove_node
(
client_node
,
reason
=
reason
)
def
remove_output
(
self
,
output_idx
:
int
,
reason
:
Optional
[
str
]
=
None
):
"""Remove the output at index `input_idx`."""
var
=
self
.
outputs
[
output_idx
]
self
.
_remove_output
(
output_idx
)
self
.
remove_client
(
var
,
(
"output"
,
output_idx
),
reason
=
reason
,
remove_if_empty
=
True
)
def
attach_feature
(
self
,
feature
:
Feature
)
->
None
:
def
attach_feature
(
self
,
feature
:
Feature
)
->
None
:
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
# Filter out literally identical `Feature`s
# Filter out literally identical `Feature`s
...
@@ -668,9 +819,7 @@ class FunctionGraph(MetaObject):
...
@@ -668,9 +819,7 @@ class FunctionGraph(MetaObject):
nodes_missing
=
nodes
.
difference
(
self
.
apply_nodes
)
nodes_missing
=
nodes
.
difference
(
self
.
apply_nodes
)
nodes_excess
=
self
.
apply_nodes
.
difference
(
nodes
)
nodes_excess
=
self
.
apply_nodes
.
difference
(
nodes
)
raise
Exception
(
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
f
"The following nodes are inappropriately cached:
\n
missing: {nodes_missing}
\n
in excess: {nodes_excess}"
nodes_missing
,
nodes_excess
,
)
)
for
node
in
nodes
:
for
node
in
nodes
:
for
i
,
variable
in
enumerate
(
node
.
inputs
):
for
i
,
variable
in
enumerate
(
node
.
inputs
):
...
@@ -684,9 +833,7 @@ class FunctionGraph(MetaObject):
...
@@ -684,9 +833,7 @@ class FunctionGraph(MetaObject):
vars_missing
=
variables
.
difference
(
self
.
variables
)
vars_missing
=
variables
.
difference
(
self
.
variables
)
vars_excess
=
self
.
variables
.
difference
(
variables
)
vars_excess
=
self
.
variables
.
difference
(
variables
)
raise
Exception
(
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
f
"The following variables are inappropriately cached:
\n
missing: {vars_missing}
\n
in excess: {vars_excess}"
vars_missing
,
vars_excess
,
)
)
for
variable
in
variables
:
for
variable
in
variables
:
if
(
if
(
...
...
tests/graph/test_fg.py
浏览文件 @
d58d482a
...
@@ -6,7 +6,7 @@ import pytest
...
@@ -6,7 +6,7 @@ import pytest
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.utils
import
MissingInputError
from
aesara.graph.utils
import
MissingInputError
from
tests.graph.utils
import
MyConstant
,
MyVariable
,
MyVariable2
,
op1
,
op2
,
op3
from
tests.graph.utils
import
MyConstant
,
My
Op
,
My
Variable
,
MyVariable2
,
op1
,
op2
,
op3
class
TestFunctionGraph
:
class
TestFunctionGraph
:
...
@@ -60,7 +60,7 @@ class TestFunctionGraph:
...
@@ -60,7 +60,7 @@ class TestFunctionGraph:
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var3
)
==
[(
var4
.
owner
,
0
),
(
"output"
,
0
)]
assert
fg
.
get_clients
(
var3
)
==
[(
"output"
,
0
),
(
var4
.
owner
,
0
)]
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
varC
=
MyConstant
(
"varC"
)
varC
=
MyConstant
(
"varC"
)
...
@@ -304,7 +304,7 @@ class TestFunctionGraph:
...
@@ -304,7 +304,7 @@ class TestFunctionGraph:
# FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
# FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
# because it doesn't check for validity of the replacement *first*.
# because it doesn't check for validity of the replacement *first*.
fg
.
replace
(
var1
,
var0
,
verbose
=
True
)
fg
.
replace
(
var1
,
var0
)
def
test_check_integrity
(
self
):
def
test_check_integrity
(
self
):
...
@@ -315,7 +315,7 @@ class TestFunctionGraph:
...
@@ -315,7 +315,7 @@ class TestFunctionGraph:
var5
=
op3
(
var4
,
var2
,
var2
)
var5
=
op3
(
var4
,
var2
,
var2
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
with
pytest
.
raises
(
Exception
,
match
=
"The nodes are .*"
):
with
pytest
.
raises
(
Exception
,
match
=
"The
following
nodes are .*"
):
fg
.
apply_nodes
.
remove
(
var5
.
owner
)
fg
.
apply_nodes
.
remove
(
var5
.
owner
)
fg
.
check_integrity
()
fg
.
check_integrity
()
...
@@ -328,7 +328,7 @@ class TestFunctionGraph:
...
@@ -328,7 +328,7 @@ class TestFunctionGraph:
fg
.
add_client
(
var2
,
(
var5
.
owner
,
1
))
fg
.
add_client
(
var2
,
(
var5
.
owner
,
1
))
with
pytest
.
raises
(
Exception
,
match
=
"The variables are.*"
):
with
pytest
.
raises
(
Exception
,
match
=
"The
following
variables are.*"
):
fg
.
variables
.
remove
(
var4
)
fg
.
variables
.
remove
(
var4
)
fg
.
check_integrity
()
fg
.
check_integrity
()
...
@@ -386,3 +386,300 @@ class TestFunctionGraph:
...
@@ -386,3 +386,300 @@ class TestFunctionGraph:
assert
var3
.
owner
in
fg
assert
var3
.
owner
in
fg
assert
var5
in
fg
assert
var5
in
fg
assert
var5
.
owner
in
fg
assert
var5
.
owner
in
fg
def
test_remove_node
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
node1_out
=
op1
(
var1
)
node2_out
=
op2
(
var2
,
node1_out
)
node3_out
=
op3
(
node2_out
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
node3_out
],
clone
=
False
)
fg
.
remove_node
(
node3_out
.
owner
)
fg
.
check_integrity
()
assert
not
fg
.
apply_nodes
fg
=
FunctionGraph
([
var1
,
var2
],
[
node2_out
,
node3_out
],
clone
=
False
)
fg
.
remove_node
(
node3_out
.
owner
)
fg
.
check_integrity
()
assert
fg
.
apply_nodes
==
{
node1_out
.
owner
,
node2_out
.
owner
}
fg
=
FunctionGraph
([
var1
,
var2
],
[
node2_out
,
node3_out
],
clone
=
False
)
fg
.
remove_node
(
node2_out
.
owner
)
fg
.
check_integrity
()
assert
not
fg
.
apply_nodes
def
test_remove_output
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
node1_out
=
op1
(
var1
)
node2_out
=
op2
(
var2
,
node1_out
)
node3_out
=
op3
(
node2_out
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
node2_out
,
node3_out
],
clone
=
False
)
fg
.
remove_output
(
0
)
fg
.
check_integrity
()
assert
fg
.
apply_nodes
==
{
node1_out
.
owner
,
node2_out
.
owner
,
node3_out
.
owner
}
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[
node3_out
]
fg
=
FunctionGraph
([
var1
,
var2
],
[
node2_out
,
node3_out
],
clone
=
False
)
fg
.
remove_output
(
1
)
fg
.
check_integrity
()
assert
fg
.
apply_nodes
==
{
node1_out
.
owner
,
node2_out
.
owner
}
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[
node2_out
]
fg
=
FunctionGraph
([
var1
,
var2
],
[
node2_out
,
node3_out
,
var1
],
clone
=
False
)
fg
.
remove_output
(
2
)
fg
.
check_integrity
()
assert
fg
.
apply_nodes
==
{
node1_out
.
owner
,
node2_out
.
owner
,
node3_out
.
owner
}
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[
node2_out
,
node3_out
]
fg
=
FunctionGraph
([
var1
,
var2
],
[
var1
],
clone
=
False
)
fg
.
remove_output
(
0
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[]
def
test_remove_output_2
(
self
):
var0
=
MyVariable
(
"var0"
)
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
var3
=
MyVariable
(
"var3"
)
var4
=
MyVariable
(
"var4"
)
op1_out
=
op1
(
var1
,
var0
)
out0
=
op2
(
op1_out
,
var2
)
out1
=
op1
(
var3
,
var4
)
out1
.
name
=
"out1"
out2
=
op1
(
out1
,
var0
)
out2
.
name
=
"out2"
out3
=
out1
fg
=
FunctionGraph
(
[
var0
,
var1
,
var2
,
var3
,
var4
],
[
out0
,
out1
,
out2
,
out3
],
clone
=
False
,
)
fg
.
remove_output
(
1
)
fg
.
check_integrity
()
assert
fg
.
outputs
==
[
out0
,
out2
,
out3
]
fg
=
FunctionGraph
(
[
var0
,
var1
,
var2
,
var3
,
var4
],
[
out0
,
out1
,
out2
,
var4
,
var4
],
clone
=
False
,
)
fg
.
remove_output
(
3
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var0
,
var1
,
var2
,
var3
,
var4
]
assert
fg
.
outputs
==
[
out0
,
out1
,
out2
,
var4
]
def
test_remove_output_3
(
self
):
var0
=
MyVariable
(
"var0"
)
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
var3
=
MyVariable
(
"var3"
)
var4
=
MyVariable
(
"var4"
)
var5
=
MyVariable
(
"var5"
)
var6
=
MyVariable
(
"var6"
)
op1_out
=
op1
(
var1
,
var0
)
out0
=
op2
(
op1_out
,
var2
)
out1
=
op1
(
var3
,
var4
)
out1
.
name
=
"out1"
out2
=
op1
(
op1_out
,
var5
)
out2
.
name
=
"out2"
out3
=
op1
(
var3
,
var6
)
out3
.
name
=
"out3"
out4
=
op1_out
out5
=
var3
fg
=
FunctionGraph
(
[
var0
,
var1
,
var2
,
var3
,
var4
,
var5
,
var6
],
[
out0
,
out1
,
out2
,
out3
,
out4
,
out5
],
clone
=
False
,
)
fg
.
remove_output
(
1
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var0
,
var1
,
var2
,
var3
,
var4
,
var5
,
var6
]
assert
fg
.
outputs
==
[
out0
,
out2
,
out3
,
out4
,
out5
]
assert
out1
not
in
fg
.
clients
def
test_remove_input
(
self
):
var0
=
MyVariable
(
"var0"
)
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
var3
=
MyVariable
(
"var3"
)
var4
=
MyVariable
(
"var4"
)
op1_out
=
op1
(
var1
,
var0
)
out0
=
op2
(
op1_out
,
var2
)
out1
=
op1
(
var3
,
var4
)
out1
.
name
=
"out1"
out2
=
op1
(
out1
,
var0
)
out2
.
name
=
"out2"
out3
=
out1
fg
=
FunctionGraph
(
[
var0
,
var1
,
var2
,
var3
,
var4
],
[
out0
,
out1
,
out2
,
out3
],
clone
=
False
,
)
fg
.
remove_input
(
4
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var0
,
var1
,
var2
,
var3
]
assert
fg
.
outputs
==
[
out0
]
def
test_remove_in_and_out
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
op1_out
=
op1
(
var2
,
var1
)
op2_out
=
op2
(
op1_out
,
var2
)
op3_out
=
op3
(
op2_out
,
var2
,
var2
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
op1_out
,
op3_out
],
clone
=
False
)
# Remove an output
fg
.
remove_output
(
1
)
fg
.
check_integrity
()
assert
fg
.
outputs
==
[
op1_out
]
assert
op3_out
not
in
fg
.
clients
assert
not
any
(
op3_out
.
owner
in
clients
for
clients
in
sum
(
fg
.
clients
.
values
(),
[])
)
# Remove an input
fg
.
remove_input
(
0
)
fg
.
check_integrity
()
assert
var1
not
in
fg
.
variables
assert
fg
.
inputs
==
[
var2
]
assert
fg
.
outputs
==
[]
assert
not
any
(
op1_out
.
owner
in
clients
for
clients
in
sum
(
fg
.
clients
.
values
(),
[])
)
def
test_remove_duplicates
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
op1_out
=
op1
(
var2
,
var1
)
op2_out
=
op2
(
op1_out
,
var2
)
op3_out
=
op3
(
op2_out
,
var2
,
var2
)
fg
=
FunctionGraph
([
var1
,
var1
,
var2
],
[
op1_out
,
op3_out
,
op3_out
],
clone
=
False
)
fg
.
remove_output
(
2
)
fg
.
check_integrity
()
assert
fg
.
outputs
==
[
op1_out
,
op3_out
]
fg
.
remove_input
(
0
)
fg
.
check_integrity
()
assert
var1
not
in
fg
.
variables
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[]
def
test_remove_output_empty
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
op1_out
=
op1
(
var1
)
op3_out
=
op3
(
op1_out
,
var2
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
op3_out
],
clone
=
False
)
fg
.
remove_output
(
0
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
not
fg
.
apply_nodes
assert
op1_out
not
in
fg
.
clients
assert
not
any
(
op1_out
.
owner
in
clients
for
clients
in
sum
(
fg
.
clients
.
values
(),
[])
)
assert
not
any
(
op3_out
.
owner
in
clients
for
clients
in
sum
(
fg
.
clients
.
values
(),
[])
)
def
test_remove_node_multi_out
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
multi_op
=
MyOp
(
"mop"
,
n_outs
=
2
)
op1_out
=
op1
(
var1
)
mop_out_1
,
mop_out_2
=
multi_op
(
op1_out
,
var2
)
op3_out
=
op3
(
mop_out_2
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
mop_out_1
,
op3_out
],
clone
=
False
)
fg
.
remove_node
(
mop_out_1
.
owner
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[]
assert
mop_out_1
not
in
fg
.
clients
assert
mop_out_2
not
in
fg
.
clients
assert
mop_out_1
not
in
fg
.
variables
assert
mop_out_2
not
in
fg
.
variables
mop1_out_1
,
mop1_out_2
=
multi_op
(
var1
)
op2_out
=
op2
(
mop1_out_1
)
op3_out
=
op3
(
mop1_out_1
,
mop1_out_2
)
fg
=
FunctionGraph
([
var1
],
[
op2_out
,
op3_out
],
clone
=
False
)
fg
.
remove_node
(
op3_out
.
owner
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
]
assert
fg
.
outputs
==
[
op2_out
]
# If we only want to track "active" variables in the graphs, the
# following would need to be true, as well
# assert mop1_out_2 not in fg.clients
# assert mop1_out_2 not in fg.variables
fg
=
FunctionGraph
([
var1
],
[
op2_out
,
op3_out
,
mop1_out_2
],
clone
=
False
)
fg
.
remove_node
(
op3_out
.
owner
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
]
assert
fg
.
outputs
==
[
op2_out
,
mop1_out_2
]
assert
mop1_out_2
in
fg
.
clients
assert
mop1_out_2
in
fg
.
variables
assert
mop1_out_2
in
fg
.
outputs
def
test_empty
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
fg
=
FunctionGraph
([
var1
,
var2
],
[],
clone
=
False
)
fg
.
check_integrity
()
assert
fg
.
inputs
==
[
var1
,
var2
]
assert
fg
.
outputs
==
[]
assert
not
fg
.
variables
assert
not
fg
.
apply_nodes
assert
fg
.
clients
==
{
var1
:
[],
var2
:
[]}
tests/graph/utils.py
浏览文件 @
d58d482a
...
@@ -46,19 +46,20 @@ def MyVariable2(name):
...
@@ -46,19 +46,20 @@ def MyVariable2(name):
class
MyOp
(
Op
):
class
MyOp
(
Op
):
def
__init__
(
self
,
name
,
dmap
=
None
,
x
=
None
):
def
__init__
(
self
,
name
,
dmap
=
None
,
x
=
None
,
n_outs
=
1
):
self
.
name
=
name
self
.
name
=
name
if
dmap
is
None
:
if
dmap
is
None
:
dmap
=
{}
dmap
=
{}
self
.
destroy_map
=
dmap
self
.
destroy_map
=
dmap
self
.
x
=
x
self
.
x
=
x
self
.
n_outs
=
n_outs
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
is_variable
,
inputs
))
inputs
=
list
(
map
(
is_variable
,
inputs
))
for
input
in
inputs
:
for
input
in
inputs
:
if
not
isinstance
(
input
.
type
,
MyType
):
if
not
isinstance
(
input
.
type
,
MyType
):
raise
Exception
(
"Error 1"
)
raise
Exception
(
"Error 1"
)
outputs
=
[
MyType
()()]
outputs
=
[
MyType
()()
for
i
in
range
(
self
.
n_outs
)
]
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
...
@@ -71,18 +72,19 @@ class MyOp(Op):
...
@@ -71,18 +72,19 @@ class MyOp(Op):
return
self
.
name
return
self
.
name
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval
=
(
self
is
other
)
or
(
rval
=
(
self
is
other
)
or
(
isinstance
(
other
,
MyOp
)
and
self
.
x
is
not
None
and
self
.
x
==
other
.
x
isinstance
(
other
,
MyOp
)
and
self
.
x
is
not
None
and
self
.
x
==
other
.
x
and
self
.
n_outs
==
other
.
n_outs
)
)
return
rval
return
rval
def
__hash__
(
self
):
def
__hash__
(
self
):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if
self
.
x
is
not
None
:
if
self
.
x
is
not
None
:
return
hash
(
self
.
x
)
return
hash
(
(
self
.
x
,
self
.
n_outs
)
)
else
:
else
:
return
id
(
self
)
return
hash
((
id
(
self
),
self
.
n_outs
)
)
class
MyOpCastType2
(
MyOp
):
class
MyOpCastType2
(
MyOp
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论