Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aca35acc
提交
aca35acc
authored
8月 02, 2013
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pep8
上级
2de0bc5e
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
71 行增加
和
63 行删除
+71
-63
fg.py
theano/gof/fg.py
+71
-61
toolbox.py
theano/gof/toolbox.py
+0
-2
没有找到文件。
theano/gof/fg.py
浏览文件 @
aca35acc
...
...
@@ -16,6 +16,7 @@ NullType = None
from
theano.gof.python25
import
OrderedDict
from
theano.misc.ordered_set
import
OrderedSet
class
InconsistencyError
(
Exception
):
"""
This exception should be thrown by listeners to FunctionGraph when the
...
...
@@ -82,7 +83,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
self
.
_features
=
[]
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self
.
apply_nodes
=
set
()
# Ditto for variable nodes
...
...
@@ -112,12 +114,12 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
profile
=
None
### Setup a Variable ###
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
clients
=
[]
...
...
@@ -165,13 +167,13 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
outputs
=
None
### clients ###
def
clients
(
self
,
r
):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i.
Tell differently, a list of (node,i) such that each node have
r as input at index i.
"""
return
r
.
clients
...
...
@@ -184,12 +186,14 @@ class FunctionGraph(utils.object2):
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...
...
@@ -202,7 +206,7 @@ class FunctionGraph(utils.object2):
print
>>
sys
.
stderr
,
'ERROR: DUPLICATE CLIENT ENTRY...'
print
>>
sys
.
stderr
,
' ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
>>
sys
.
stderr
,
' CLIENTS'
,
repr
(
r
.
clients
)
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
if
not
r
.
clients
:
if
prune
:
self
.
__prune_r__
([
r
])
...
...
@@ -210,9 +214,7 @@ class FunctionGraph(utils.object2):
return
True
return
False
### import ###
def
__import_r__
(
self
,
variables
):
global
NullType
if
NullType
is
None
:
...
...
@@ -225,14 +227,15 @@ class FunctionGraph(utils.object2):
self
.
__import__
(
apply_node
)
for
r
in
variables
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
self
.
__setup_r__
(
r
)
self
.
variables
.
add
(
r
)
def
__import__
(
self
,
apply_node
,
check
=
True
):
def
__import__
(
self
,
apply_node
,
check
=
True
):
node
=
apply_node
# We import the nodes in topological order. We only are interested
...
...
@@ -248,7 +251,9 @@ class FunctionGraph(utils.object2):
for
r
in
node
.
inputs
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
#Verbose error message
#Show a complete chain of variables from the missing input to an output
...
...
@@ -330,9 +335,7 @@ class FunctionGraph(utils.object2):
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
)
### prune ###
def
__prune_r__
(
self
,
variables
):
# Prunes the owners of the variables.
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
...
...
@@ -362,10 +365,7 @@ class FunctionGraph(utils.object2):
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
#self.__prune_r__(node.inputs)
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
Changes node.inputs[i] to new_r.
...
...
@@ -381,18 +381,18 @@ class FunctionGraph(utils.object2):
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
r
,
new_r
)
" same as the type of the original Variable."
,
r
,
new_r
)
self
.
outputs
[
i
]
=
new_r
else
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Cannot operate on
%
s because it does not"
" belong to this FunctionGraph"
%
node
)
" belong to this FunctionGraph"
%
node
)
r
=
node
.
inputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
r
,
new_r
)
" same as the type of the original Variable."
,
r
,
new_r
)
node
.
inputs
[
i
]
=
new_r
if
r
is
new_r
:
...
...
@@ -404,7 +404,8 @@ class FunctionGraph(utils.object2):
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
if
prune
:
self
.
__prune_r__
([
r
])
...
...
@@ -426,7 +427,7 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
...
...
@@ -440,11 +441,9 @@ class FunctionGraph(utils.object2):
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
extend
(
self
,
feature
):
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature"
)
"renamed to FunctionGraph.attach_feature"
)
return
self
.
attach_feature
(
feature
)
def
attach_feature
(
self
,
feature
):
...
...
@@ -455,7 +454,7 @@ class FunctionGraph(utils.object2):
# Filter out literally identical features
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
# Filter out functionally identical features.
# Features may use their on_attach method to raise
...
...
@@ -481,7 +480,9 @@ class FunctionGraph(utils.object2):
"""WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined.
Calls feature.on_detach(function_graph) if an on_detach method
is defined.
"""
try
:
self
.
_features
.
remove
(
feature
)
...
...
@@ -491,9 +492,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
detach
(
self
)
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
Calls
...
...
@@ -518,7 +517,6 @@ class FunctionGraph(utils.object2):
else
:
raise
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
Returns a dictionary d such that:
...
...
@@ -534,9 +532,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
return
d
### misc ###
def
toposort
(
self
):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
...
...
@@ -552,8 +548,8 @@ class FunctionGraph(utils.object2):
if
len
(
self
.
apply_nodes
)
<
2
:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker
produces
# 1-element graphs.
# This special case happens a lot because the OpWiseCLinker
#
produces
1-element graphs.
return
list
(
self
.
apply_nodes
)
fg
=
self
...
...
@@ -568,14 +564,15 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
clients of any destroyed inputs have already computed their outputs.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
ords
=
OrderedDict
()
ords
=
OrderedDict
()
assert
isinstance
(
self
.
_features
,
list
)
for
feature
in
self
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
...
...
@@ -586,12 +583,13 @@ class FunctionGraph(utils.object2):
+
". Nondeterministic object is "
+
str
(
orderings
))
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
# eliminate duplicate prereqs
for
(
node
,
prereqs
)
in
ords
.
items
():
for
(
node
,
prereqs
)
in
ords
.
items
():
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
return
ords
...
...
@@ -624,34 +622,48 @@ class FunctionGraph(utils.object2):
if
self
.
apply_nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
node
in
nodes
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
for
i
,
variable
in
enumerate
(
node
.
inputs
):
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
if
(
node
,
i
)
not
in
variable
.
clients
:
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
if
set
(
self
.
variables
)
!=
variables
:
missing
=
variables
.
difference
(
self
.
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
variable
in
variables
:
if
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
):
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
for
node
,
i
in
variable
.
clients
:
if
node
==
'output'
:
if
self
.
outputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
continue
if
node
not
in
nodes
:
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
if
node
.
inputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
...
...
@@ -659,9 +671,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
return
self
.
__str__
()
### clone ###
def
clone
(
self
):
"""WRITEME"""
return
self
.
clone_get_equiv
()[
0
]
...
...
@@ -671,7 +681,7 @@ class FunctionGraph(utils.object2):
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
self
.
check_integrity
()
e
=
FunctionGraph
([
equiv
[
i
]
for
i
in
self
.
inputs
],
[
equiv
[
o
]
for
o
in
self
.
outputs
])
[
equiv
[
o
]
for
o
in
self
.
outputs
])
e
.
check_integrity
()
for
feature
in
self
.
_features
:
e
.
attach_feature
(
feature
)
...
...
theano/gof/toolbox.py
浏览文件 @
aca35acc
...
...
@@ -3,11 +3,9 @@ import time
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
OrderedDict
from
theano.gof
import
graph
class
AlreadyThere
(
Exception
):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论