Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
75e573d1
提交
75e573d1
authored
8月 12, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
numpydoc for theano/gof/fg.py
上级
18c54eca
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
138 行增加
和
70 行删除
+138
-70
fg.py
theano/gof/fg.py
+138
-70
没有找到文件。
theano/gof/fg.py
浏览文件 @
75e573d1
"""
"""
fg.py: fg stands for FunctionGraph
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
Contains the FunctionGraph class and exception
types that it can raise
types that it can raise.
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
import
sys
import
sys
...
@@ -23,10 +24,13 @@ NullType = None
...
@@ -23,10 +24,13 @@ NullType = None
class
CachedConstantError
(
Exception
):
class
CachedConstantError
(
Exception
):
"""An exception thrown when we put in a FunctionGraph a Constant
"""
that is cached. This should not happen as the user can reuse this
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
cached constant in other FunctionGraph.
"""
"""
pass
pass
...
@@ -34,24 +38,28 @@ class InconsistencyError(Exception):
...
@@ -34,24 +38,28 @@ class InconsistencyError(Exception):
"""
"""
This exception should be thrown by listeners to FunctionGraph when the
This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid.
graph's state is invalid.
"""
"""
pass
pass
class
MissingInputError
(
Exception
):
class
MissingInputError
(
Exception
):
"""
"""
A symbolic input needed to compute the outputs is missing.
A symbolic input needed to compute the outputs is missing.
"""
"""
pass
pass
class
FunctionGraph
(
utils
.
object2
):
class
FunctionGraph
(
utils
.
object2
):
"""
WRITEME
"""
A FunctionGraph represents a subgraph bound by a set of input variables and a
WRITEME
set of output variables, ie a subgraph that specifies a theano function.
A FunctionGraph represents a subgraph bound by a set of input variables and
The inputs list should contain all the inputs
a set of output variables, ie a subgraph that specifies a theano function.
on which the outputs depend. Variables of type Constant are
The inputs list should contain all the inputs on which the outputs depend.
not counted as inputs.
Variables of type Constant are
not counted as inputs.
The FunctionGraph supports the replace operation which allows to replace a
The FunctionGraph supports the replace operation which allows to replace a
variable in the subgraph by another, e.g. replace (x + x).out by (2
variable in the subgraph by another, e.g. replace (x + x).out by (2
...
@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2):
...
@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2):
Historically, the FunctionGraph was called an Env. Keep this in mind
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
while reading out-of-date documentation, e-mail support threads, etc.
"""
The constructor creates a FunctionGraph which operates on the subgraph
bound by the inputs and outputs sets.
def
__init__
(
self
,
inputs
,
outputs
,
features
=
None
,
clone
=
True
):
This class keeps a pointer to the inputs and outputs, and also modifies
"""
them.
Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets.
This class keeps a pointer to the inputs and outputs, and also modifies
#TODO: document what variables are[not] set in the FunctionGraph when a
them.
feature is added via the constructor. How constructed is the
FunctionGraph?
#TODO: document what variables are[not] set in the FunctionGraph when a feature
Parameters
is added via the constructor. How constructed is the FunctionGraph?
----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitely
Notes
passed.
-----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
:param inputs: inputs nodes of the graph, usually declared by the user
"""
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is
def
__init__
(
self
,
inputs
,
outputs
,
features
=
None
,
clone
=
True
):
useful to remove the constant cache problem.
"""
if
clone
:
if
clone
:
inputs
,
outputs
=
graph
.
clone
(
inputs
,
outputs
)
inputs
,
outputs
=
graph
.
clone
(
inputs
,
outputs
)
...
@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2):
...
@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2):
# self.execute_callbacks('on_setup_node', node)
# self.execute_callbacks('on_setup_node', node)
def
disown
(
self
):
def
disown
(
self
):
""" WRITEME
"""
Cleans up all of this FunctionGraph's nodes and variables so they are not
WRITEME
associated with this FunctionGraph anymore.
Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called.
The FunctionGraph should not be used anymore after disown is called.
This may not clean everything this FunctionGraph's features set in the
This may not clean everything this FunctionGraph's features set in the
nodes and variables. If there are no features, this should set
nodes and variables. If there are no features, this should set
them back to what they were originally.
them back to what they were originally.
"""
"""
for
apply_node
in
self
.
apply_nodes
:
for
apply_node
in
self
.
apply_nodes
:
del
apply_node
.
fgraph
del
apply_node
.
fgraph
...
@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2):
...
@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2):
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"""
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Set of all the (node, i) pairs such that node.inputs[i] is r.
T
ell
differently, a list of (node,i) such that each node have
T
old
differently, a list of (node,i) such that each node have
r as input at index i.
r as input at index i.
"""
"""
return
r
.
clients
return
r
.
clients
def
__add_clients__
(
self
,
r
,
new_clients
):
def
__add_clients__
(
self
,
r
,
new_clients
):
""" WRITEME
"""
r -> variable
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
Updates the list of clients of r with new_clients.
Updates the list of clients of r with new_clients.
WRITEME
Parameters
----------
r
Variable.
new_clients
List of (node, i) pairs such that node.inputs[i] is r.
"""
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
...
@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2):
...
@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
,
reason
=
None
):
prune
=
True
,
reason
=
None
):
""" WRITEME
"""
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
Removes all from the clients list of r.
Removes all from the clients list of r.
WRITEME
Parameters
----------
r
Variable.
clients_to_remove
List of (op, i) pairs such that node.inputs[i] is not r anymore.
"""
"""
for
entry
in
clients_to_remove
:
for
entry
in
clients_to_remove
:
r
.
clients
.
remove
(
entry
)
r
.
clients
.
remove
(
entry
)
...
@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2):
...
@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2):
if
config
.
exception_verbosity
==
'high'
:
if
config
.
exception_verbosity
==
'high'
:
def
find_path_to
(
output_var
,
input_var
):
def
find_path_to
(
output_var
,
input_var
):
""" Returns a list of each variable on a (not necessarily unique)
"""
path from input_var to output_var, where each variable in the
Returns a list of each variable on a (not
list has the preceding variable as one of its inputs.
necessarily unique) path from input_var to
Returns None if no path exists"""
output_var, where each variable in the list has
the preceding variable as one of its inputs.
Returns None if no path exists.
"""
# If output and input are the same we have a singleton path
# If output and input are the same we have a singleton path
if
output_var
is
input_var
:
if
output_var
is
input_var
:
return
[
output_var
]
return
[
output_var
]
...
@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2):
...
@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2):
# prune #
# prune #
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
"""Should be called for variable that aren't used anymore:
"""
len(var.clients) == 0
Should be called for variable that aren't used anymore:
len(var.clients) == 0.
This do not mean we will remove it from fgraph.variables. If
This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the owner stay in the fgraph as other outputs are still used,
the variable will
be
stay in fgraph.variables.
the variable will stay in fgraph.variables.
"""
"""
# Prunes the owners of the variables.
# Prunes the owners of the variables.
...
@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2):
...
@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2):
del
variable
.
fgraph
del
variable
.
fgraph
def
__prune__
(
self
,
apply_node
,
reason
=
None
):
def
__prune__
(
self
,
apply_node
,
reason
=
None
):
"""Always called on owner of pruned variable from the graph.
"""
Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other
This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
outputs are still used, we will keep the node in the graph.
...
@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2):
...
@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2):
# change input #
# change input #
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""
WRITEME
"""
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
"""
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if
node
==
'output'
:
if
node
==
'output'
:
...
@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2):
...
@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2):
# replace #
# replace #
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
""" WRITEME
"""
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
For every node that uses r as input, makes it use new_r instead.
"""
"""
if
verbose
is
None
:
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
verbose
=
config
.
optimizer_verbose
...
@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2):
...
@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2):
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def
replace_all
(
self
,
pairs
,
reason
=
None
):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
"""WRITEME"""
"""
WRITEME
"""
for
r
,
new_r
in
pairs
:
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
attach_feature
(
self
,
feature
):
def
attach_feature
(
self
,
feature
):
"""
"""
Adds a gof.toolbox.Feature to this function_graph
Adds a gof.toolbox.Feature to this function_graph and triggers its
and triggers its on_attach callback
on_attach callback.
"""
"""
# Filter out literally identical features
# Filter out literally identical features
if
feature
in
self
.
_features
:
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
...
@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2):
...
@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2):
self
.
_features
.
append
(
feature
)
self
.
_features
.
append
(
feature
)
def
remove_feature
(
self
,
feature
):
def
remove_feature
(
self
,
feature
):
"""WRITEME
"""
WRITEME
Removes the feature from the graph.
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method
Calls feature.on_detach(function_graph) if an on_detach method
...
@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2):
...
@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2):
# callback utils #
# callback utils #
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
"""
WRITEME
Calls
Calls
getattr(feature, name)(*args)
getattr(feature, name)(*args)
for each feature which has a method called after name.
for each feature which has a method called after name.
"""
"""
t0
=
time
.
time
()
t0
=
time
.
time
()
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
...
@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2):
...
@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2):
self
.
execute_callbacks_time
+=
time
.
time
()
-
t0
self
.
execute_callbacks_time
+=
time
.
time
()
-
t0
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
"""
WRITEME
Returns a dictionary d such that:
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
For each feature which has a method called after name.
"""
"""
d
=
{}
d
=
{}
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
...
@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2):
...
@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2):
# misc #
# misc #
def
toposort
(
self
):
def
toposort
(
self
):
"""WRITEME
"""
Returns an ordering of the graph's Apply nodes such that:
WRITEME
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
Return an ordering of the graph's Apply nodes such that:
an 'orderings' method.
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
If a feature has an 'orderings' method, it will be called with
this FunctionGraph as sole argument. It should return a dictionary of
this FunctionGraph as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
that should be computed before the key node.
"""
"""
if
len
(
self
.
apply_nodes
)
<
2
:
if
len
(
self
.
apply_nodes
)
<
2
:
# optimization
# optimization
...
@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2):
...
@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2):
before node itself can be evaluated.
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
all clients of any destroyed inputs have already computed their outputs.
outputs.
:note: This only calls the orderings() fct on all features. It does not
Notes
take care of computing dependencies by itself.
-----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
"""
ords
=
OrderedDict
()
ords
=
OrderedDict
()
...
@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2):
...
@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2):
return
ords
return
ords
def
check_integrity
(
self
):
def
check_integrity
(
self
):
"""WRITEME
"""
WRITEME
Call this for a diagnosis if things go awry.
Call this for a diagnosis if things go awry.
"""
"""
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
if
self
.
apply_nodes
!=
nodes
:
if
self
.
apply_nodes
!=
nodes
:
...
@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2):
...
@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2):
# clone #
# clone #
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
):
"""WRITEME"""
"""
WRITEME
"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
def
clone_get_equiv
(
self
,
check_integrity
=
True
):
def
clone_get_equiv
(
self
,
check_integrity
=
True
):
"""WRITEME"""
"""
WRITEME
"""
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
if
check_integrity
:
if
check_integrity
:
self
.
check_integrity
()
self
.
check_integrity
()
...
@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2):
...
@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2):
return
e
,
equiv
return
e
,
equiv
def
__getstate__
(
self
):
def
__getstate__
(
self
):
"""This is needed as some feature introduce instancemethod and
"""
this is not picklable.
This is needed as some features introduce instance methods.
This is not picklable.
"""
"""
d
=
self
.
__dict__
.
copy
()
d
=
self
.
__dict__
.
copy
()
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论