Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
01020a18
提交
01020a18
authored
4月 02, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add memo option, automatic inputs, and copy options to FunctionGraph
上级
52bad109
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
121 行增加
和
77 行删除
+121
-77
fg.py
aesara/graph/fg.py
+103
-76
test_fg.py
tests/graph/test_fg.py
+17
-0
utils.py
tests/graph/utils.py
+1
-1
没有找到文件。
aesara/graph/fg.py
浏览文件 @
01020a18
...
...
@@ -2,14 +2,14 @@
import
time
import
warnings
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
List
,
NoReturn
,
Optional
,
Tuple
,
Union
import
aesara
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
,
applys_between
from
aesara.graph.basic
import
as_string
as
graph_as_string
from
aesara.graph.basic
import
clone
as
clone_graph
from
aesara.graph.basic
import
clone_get_equiv
,
io_toposort
,
vars_between
from
aesara.graph.toolbox
import
AlreadyThere
,
ReplaceValidate
from
aesara.graph.basic
import
clone_get_equiv
,
graph_inputs
,
io_toposort
,
vars_between
from
aesara.graph.toolbox
import
AlreadyThere
,
Feature
,
ReplaceValidate
from
aesara.graph.utils
import
MetaObject
,
TestValueError
,
get_variable_trace_string
from
aesara.misc.ordered_set
import
OrderedSet
...
...
@@ -44,23 +44,23 @@ class FunctionGraph(MetaObject):
A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an Aesara function.
The inputs list should contain all the inputs on which the outputs depend.
`
Variable`s of type `Constant
` are not counted as inputs.
`
`Variable``s of type ``Constant`
` are not counted as inputs.
The `FunctionGraph` supports the replace operation which allows to replace
a variable in the subgraph by another, e.g. replace ``(x + x).out`` by
``(2
* x).out``. This is the basis for optimization in Aesara.
a variable in the subgraph by another, e.g. replace ``(x + x).out`` by
``(2
* x).out``. This is the basis for optimization in Aesara.
This class is also responsible for verifying that a graph is valid
(ie, all the dtypes and broadcast patterns are compatible with the
way the
the `Variable`s are used) and for tracking the `Variable
`s with
a `
clients` field that specifies which `Apply` nodes use the `Variable
`.
The `
clients` field combined with the `Variable.owner
` field and the
`
Apply` nodes' `Apply.inputs
` field allows the graph to be traversed in
way the
``Variable``s are used) and for tracking the ``Variable`
`s with
a `
`clients`` field that specifies which ``Apply`` nodes use the ``Variable`
`.
The `
`clients`` field combined with the ``Variable.owner`
` field and the
`
`Apply`` nodes' ``Apply.inputs`
` field allows the graph to be traversed in
both directions.
It can also be extended with new features using
`
FunctionGraph.attach_feature`(<Feature instance>)
.
See `
Feature
` for event types and documentation.
`
`FunctionGraph.attach_feature(<Feature instance>)``
.
See `
`Feature`
` for event types and documentation.
Extra features allow the `FunctionGraph` to verify new properties of
a graph as it is optimized.
...
...
@@ -73,52 +73,59 @@ class FunctionGraph(MetaObject):
This class keeps a pointer to the inputs and outputs, and also modifies
them.
Parameters
----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Notes
-----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
"""
def
__init__
(
self
,
inputs
,
outputs
,
features
=
None
,
clone
=
True
,
update_mapping
=
None
):
def
__init__
(
self
,
inputs
:
Optional
[
List
[
Variable
]]
=
None
,
outputs
:
Optional
[
List
[
Variable
]]
=
None
,
features
:
Optional
[
List
[
Feature
]]
=
None
,
clone
:
bool
=
True
,
update_mapping
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
memo
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
copy_inputs
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
):
"""
Create a
n FunctionGraph which operates on the subgraph bound by
the
inputs and outputs sets
.
Create a
`FunctionGraph` which operates on the subgraph between
the
`inputs` and `outputs`
.
Parameters
----------
inputs : list of aesara.graph.basic.Variable
Inputs nodes of the graph, usually declared by the user
outputs : list of aesara.graph.basic.Variable
Outputs nodes of the graph.
clone : boolean
If true, we will clone the graph. This is useful to remove the
constant cache problem.
features : list of aesara.graph.toolbox.Feature
inputs
Input variables of the graph.
outputs
Output variables of the graph.
clone
If ``True``, the graph will be cloned.
features
A list of features to be added to the `FunctionGraph`.
update_mapping
: dict
Mapping between the
inputs with updates and the outputs
update_mapping
Mapping between the
`inputs` with updates and the `outputs`
corresponding to their updates.
memo
See ``clone_get_equiv``.
copy_inputs
See ``clone_get_equiv``.
copy_orphans
See ``clone_get_equiv``.
"""
if
outputs
is
None
:
raise
ValueError
(
"No outputs specified"
)
if
clone
:
inputs
,
outputs
=
clone_graph
(
inputs
,
outputs
)
if
not
isinstance
(
inputs
,
list
):
raise
TypeError
(
"Argument `inputs` should be a list"
)
if
inputs
is
None
:
inputs
=
[
i
for
i
in
graph_inputs
(
outputs
)]
if
not
isinstance
(
outputs
,
list
):
raise
TypeError
(
"Argument `outputs` should be a list"
)
if
clone
:
memo
=
clone_get_equiv
(
inputs
,
outputs
,
copy_inputs
=
copy_inputs
,
copy_orphans
=
copy_orphans
,
memo
=
memo
,
)
outputs
=
[
memo
[
o
]
for
o
in
outputs
]
inputs
=
[
memo
[
i
]
for
i
in
inputs
]
self
.
execute_callbacks_time
=
0
self
.
execute_callbacks_times
=
{}
...
...
@@ -165,7 +172,7 @@ class FunctionGraph(MetaObject):
self
.
profile
=
None
self
.
update_mapping
=
update_mapping
def
add_input
(
self
,
var
,
check
=
True
)
:
def
add_input
(
self
,
var
:
Variable
,
check
:
bool
=
True
)
->
NoReturn
:
"""Add a new variable as an input to this `FunctionGraph`.
Parameters
...
...
@@ -180,7 +187,7 @@ class FunctionGraph(MetaObject):
self
.
setup_var
(
var
)
self
.
variables
.
add
(
var
)
def
setup_var
(
self
,
var
)
:
def
setup_var
(
self
,
var
:
Variable
)
->
NoReturn
:
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
...
...
@@ -190,7 +197,7 @@ class FunctionGraph(MetaObject):
"""
self
.
clients
.
setdefault
(
var
,
[])
def
setup_node
(
self
,
node
)
:
def
setup_node
(
self
,
node
:
Apply
)
->
NoReturn
:
"""Set up node so it belongs to this `FunctionGraph`.
Parameters
...
...
@@ -214,14 +221,8 @@ class FunctionGraph(MetaObject):
" the values must be tuples or lists."
)
def
disown
(
self
):
"""
Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called.
"""
def
disown
(
self
)
->
NoReturn
:
"""Clear internal variables."""
for
f
in
self
.
_features
:
self
.
remove_feature
(
f
)
self
.
clients
=
{}
...
...
@@ -232,11 +233,11 @@ class FunctionGraph(MetaObject):
self
.
profile
=
None
self
.
update_mapping
=
None
def
get_clients
(
self
,
var
)
:
def
get_clients
(
self
,
var
:
Variable
)
->
List
[
Tuple
[
Apply
,
int
]]
:
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`."""
return
self
.
clients
[
var
]
def
add_client
(
self
,
var
,
new_client
)
:
def
add_client
(
self
,
var
:
Variable
,
new_client
:
Tuple
[
Apply
,
int
])
->
NoReturn
:
"""Update the clients of `var` with `new_clients`.
Parameters
...
...
@@ -248,7 +249,9 @@ class FunctionGraph(MetaObject):
"""
self
.
clients
[
var
]
.
append
(
new_client
)
def
remove_client
(
self
,
var
,
client_to_remove
,
reason
=
None
):
def
remove_client
(
self
,
var
:
Variable
,
client_to_remove
:
Tuple
[
Apply
,
int
],
reason
:
str
=
None
)
->
NoReturn
:
"""Recursively removes clients of a variable.
This is the main method to remove variables or `Apply` nodes from
...
...
@@ -312,7 +315,9 @@ class FunctionGraph(MetaObject):
for
i
,
in_var
in
enumerate
(
apply_node
.
inputs
):
removal_stack
.
append
((
in_var
,
(
apply_node
,
i
)))
def
import_var
(
self
,
var
,
reason
=
None
,
import_missing
=
False
):
def
import_var
(
self
,
var
:
Variable
,
reason
:
str
=
None
,
import_missing
:
bool
=
False
)
->
NoReturn
:
"""Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node.
...
...
@@ -348,7 +353,13 @@ class FunctionGraph(MetaObject):
self
.
setup_var
(
var
)
self
.
variables
.
add
(
var
)
def
import_node
(
self
,
apply_node
,
check
=
True
,
reason
=
None
,
import_missing
=
False
):
def
import_node
(
self
,
apply_node
:
Apply
,
check
:
bool
=
True
,
reason
:
str
=
None
,
import_missing
:
bool
=
False
,
)
->
NoReturn
:
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters:
...
...
@@ -407,7 +418,14 @@ class FunctionGraph(MetaObject):
self
.
add_client
(
input
,
(
node
,
i
))
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
def
change_input
(
self
,
node
,
i
,
new_var
,
reason
=
None
,
import_missing
=
False
):
def
change_input
(
self
,
node
:
Apply
,
i
:
int
,
new_var
:
Variable
,
reason
:
str
=
None
,
import_missing
:
bool
=
False
,
)
->
NoReturn
:
"""Change ``node.inputs[i]`` to `new_var`.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
...
...
@@ -462,7 +480,14 @@ class FunctionGraph(MetaObject):
# reverted later.
self
.
execute_callbacks
(
"on_change_input"
,
node
,
i
,
r
,
new_var
,
reason
=
reason
)
def
replace
(
self
,
var
,
new_var
,
reason
=
None
,
verbose
=
None
,
import_missing
=
False
):
def
replace
(
self
,
var
:
Variable
,
new_var
:
Variable
,
reason
:
str
=
None
,
verbose
:
bool
=
None
,
import_missing
:
bool
=
False
,
)
->
NoReturn
:
"""Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
...
...
@@ -526,12 +551,12 @@ class FunctionGraph(MetaObject):
node
,
i
,
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
def
replace_all
(
self
,
pairs
,
**
kwargs
)
:
def
replace_all
(
self
,
pairs
:
List
[
Tuple
[
Variable
,
Variable
]],
**
kwargs
)
->
NoReturn
:
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for
var
,
new_var
in
pairs
:
self
.
replace
(
var
,
new_var
,
**
kwargs
)
def
attach_feature
(
self
,
feature
)
:
def
attach_feature
(
self
,
feature
:
Feature
)
->
NoReturn
:
"""
Adds a graph.toolbox.Feature to this function_graph and triggers its
on_attach callback.
...
...
@@ -561,7 +586,7 @@ class FunctionGraph(MetaObject):
# Add the feature
self
.
_features
.
append
(
feature
)
def
remove_feature
(
self
,
feature
)
:
def
remove_feature
(
self
,
feature
:
Feature
)
->
NoReturn
:
"""
Removes the feature from the graph.
...
...
@@ -578,7 +603,7 @@ class FunctionGraph(MetaObject):
if
detach
is
not
None
:
detach
(
self
)
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
)
:
def
execute_callbacks
(
self
,
name
:
str
,
*
args
,
**
kwargs
)
->
NoReturn
:
"""Execute callbacks
Calls `getattr(feature, name)(*args)` for each feature which has
...
...
@@ -599,7 +624,7 @@ class FunctionGraph(MetaObject):
self
.
execute_callbacks_times
[
feature
]
+=
time
.
time
()
-
tf0
self
.
execute_callbacks_time
+=
time
.
time
()
-
t0
def
collect_callbacks
(
self
,
name
,
*
args
)
:
def
collect_callbacks
(
self
,
name
:
str
,
*
args
)
->
Dict
[
Feature
,
Any
]
:
"""Collects callbacks
Returns a dictionary d such that
...
...
@@ -615,7 +640,7 @@ class FunctionGraph(MetaObject):
d
[
feature
]
=
fn
(
*
args
)
return
d
def
toposort
(
self
):
def
toposort
(
self
)
->
List
[
Apply
]
:
"""Toposort
Return an ordering of the graph's Apply nodes such that
...
...
@@ -643,7 +668,7 @@ class FunctionGraph(MetaObject):
return
order
def
orderings
(
self
):
def
orderings
(
self
)
->
Dict
[
Apply
,
List
[
Apply
]]
:
"""Return `dict` `d` s.t. `d[node]` is a list of nodes that must be evaluated before `node` itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
...
...
@@ -689,7 +714,7 @@ class FunctionGraph(MetaObject):
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
return
ords
def
check_integrity
(
self
):
def
check_integrity
(
self
)
->
NoReturn
:
"""
Call this for a diagnosis if things go awry.
...
...
@@ -745,14 +770,16 @@ class FunctionGraph(MetaObject):
def
__repr__
(
self
):
return
f
"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
)
->
"FunctionGraph"
:
"""
Clone the graph and get a memo( a dict )that map old node to new node
"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
def
clone_get_equiv
(
self
,
check_integrity
=
True
,
attach_feature
=
True
):
def
clone_get_equiv
(
self
,
check_integrity
:
bool
=
True
,
attach_feature
:
bool
=
True
)
->
Union
[
"FunctionGraph"
,
Dict
[
Variable
,
Variable
]]:
"""Clone the graph and get a dict that maps old nodes to new ones
Parameters:
...
...
@@ -810,7 +837,7 @@ class FunctionGraph(MetaObject):
if
hasattr
(
feature
,
"unpickle"
):
feature
.
unpickle
(
self
)
def
__contains__
(
self
,
item
)
:
def
__contains__
(
self
,
item
:
Union
[
Variable
,
Apply
])
->
bool
:
if
isinstance
(
item
,
Variable
):
return
item
in
self
.
variables
elif
isinstance
(
item
,
Apply
):
...
...
tests/graph/test_fg.py
浏览文件 @
01020a18
...
...
@@ -41,6 +41,10 @@ class TestFunctionGraph:
var3
=
op1
(
var1
)
FunctionGraph
([
var3
],
[
var2
],
clone
=
False
)
with
pytest
.
raises
(
ValueError
):
var3
=
op1
(
var1
)
FunctionGraph
([
var3
],
clone
=
False
)
def
test_init
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
...
...
@@ -58,6 +62,19 @@ class TestFunctionGraph:
assert
fg
.
get_clients
(
var3
)
==
[(
var4
.
owner
,
0
),
(
"output"
,
0
)]
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
fg
=
FunctionGraph
(
outputs
=
[
var3
,
var4
],
clone
=
False
)
assert
fg
.
inputs
==
[
var1
,
var2
]
memo
=
{}
fg
=
FunctionGraph
(
outputs
=
[
var3
,
var4
],
clone
=
True
,
memo
=
memo
)
assert
memo
[
var1
]
.
type
==
var1
.
type
assert
memo
[
var1
]
.
name
==
var1
.
name
assert
memo
[
var2
]
.
type
==
var2
.
type
assert
memo
[
var2
]
.
name
==
var2
.
name
assert
var3
in
memo
assert
var4
in
memo
def
test_remove_client
(
self
):
var1
=
MyVariable
(
"var1"
)
var2
=
MyVariable
(
"var2"
)
...
...
tests/graph/utils.py
浏览文件 @
01020a18
...
...
@@ -58,7 +58,7 @@ class MyOp(Op):
return
Apply
(
self
,
inputs
,
outputs
)
def
perform
(
self
,
node
,
inputs
,
outputs
):
outputs
[
0
]
=
np
.
array
(
inputs
)
outputs
[
0
]
=
np
.
array
(
inputs
,
dtype
=
np
.
object
)
def
__str__
(
self
):
return
self
.
name
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论