Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aca35acc
提交
aca35acc
authored
8月 02, 2013
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pep8
上级
2de0bc5e
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
58 行增加
和
50 行删除
+58
-50
fg.py
theano/gof/fg.py
+58
-48
toolbox.py
theano/gof/toolbox.py
+0
-2
没有找到文件。
theano/gof/fg.py
浏览文件 @
aca35acc
...
@@ -16,6 +16,7 @@ NullType = None
...
@@ -16,6 +16,7 @@ NullType = None
from
theano.gof.python25
import
OrderedDict
from
theano.gof.python25
import
OrderedDict
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
class
InconsistencyError
(
Exception
):
class
InconsistencyError
(
Exception
):
"""
"""
This exception should be thrown by listeners to FunctionGraph when the
This exception should be thrown by listeners to FunctionGraph when the
...
@@ -82,7 +83,8 @@ class FunctionGraph(utils.object2):
...
@@ -82,7 +83,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
# so I probably am) this should be a set.
self
.
_features
=
[]
self
.
_features
=
[]
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self
.
apply_nodes
=
set
()
self
.
apply_nodes
=
set
()
# Ditto for variable nodes
# Ditto for variable nodes
...
@@ -112,12 +114,12 @@ class FunctionGraph(utils.object2):
...
@@ -112,12 +114,12 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
variable_locks
=
{}
self
.
profile
=
None
self
.
profile
=
None
### Setup a Variable ###
### Setup a Variable ###
def
__setup_r__
(
self
,
r
):
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
# sets up r so it belongs to this fgraph
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
fgraph
=
self
r
.
clients
=
[]
r
.
clients
=
[]
...
@@ -165,13 +167,13 @@ class FunctionGraph(utils.object2):
...
@@ -165,13 +167,13 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
inputs
=
None
self
.
outputs
=
None
self
.
outputs
=
None
### clients ###
### clients ###
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"""
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i.
Tell differently, a list of (node,i) such that each node have
r as input at index i.
"""
"""
return
r
.
clients
return
r
.
clients
...
@@ -184,12 +186,14 @@ class FunctionGraph(utils.object2):
...
@@ -184,12 +186,14 @@ class FunctionGraph(utils.object2):
"""
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
""" WRITEME
""" WRITEME
r -> variable
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...
@@ -210,9 +214,7 @@ class FunctionGraph(utils.object2):
...
@@ -210,9 +214,7 @@ class FunctionGraph(utils.object2):
return
True
return
True
return
False
return
False
### import ###
### import ###
def
__import_r__
(
self
,
variables
):
def
__import_r__
(
self
,
variables
):
global
NullType
global
NullType
if
NullType
is
None
:
if
NullType
is
None
:
...
@@ -225,14 +227,15 @@ class FunctionGraph(utils.object2):
...
@@ -225,14 +227,15 @@ class FunctionGraph(utils.object2):
self
.
__import__
(
apply_node
)
self
.
__import__
(
apply_node
)
for
r
in
variables
:
for
r
in
variables
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
isinstance
(
r
.
type
,
NullType
):
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
self
.
__setup_r__
(
r
)
self
.
__setup_r__
(
r
)
self
.
variables
.
add
(
r
)
self
.
variables
.
add
(
r
)
def
__import__
(
self
,
apply_node
,
check
=
True
):
def
__import__
(
self
,
apply_node
,
check
=
True
):
node
=
apply_node
node
=
apply_node
# We import the nodes in topological order. We only are interested
# We import the nodes in topological order. We only are interested
...
@@ -248,7 +251,9 @@ class FunctionGraph(utils.object2):
...
@@ -248,7 +251,9 @@ class FunctionGraph(utils.object2):
for
r
in
node
.
inputs
:
for
r
in
node
.
inputs
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
#Verbose error message
#Verbose error message
#Show a complete chain of variables from the missing input to an output
#Show a complete chain of variables from the missing input to an output
...
@@ -330,9 +335,7 @@ class FunctionGraph(utils.object2):
...
@@ -330,9 +335,7 @@ class FunctionGraph(utils.object2):
assert
node
.
fgraph
is
self
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
)
self
.
execute_callbacks
(
'on_import'
,
node
)
### prune ###
### prune ###
def
__prune_r__
(
self
,
variables
):
def
__prune_r__
(
self
,
variables
):
# Prunes the owners of the variables.
# Prunes the owners of the variables.
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
...
@@ -362,10 +365,7 @@ class FunctionGraph(utils.object2):
...
@@ -362,10 +365,7 @@ class FunctionGraph(utils.object2):
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
#self.__prune_r__(node.inputs)
#self.__prune_r__(node.inputs)
### change input ###
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
"""WRITEME
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
...
@@ -404,7 +404,8 @@ class FunctionGraph(utils.object2):
...
@@ -404,7 +404,8 @@ class FunctionGraph(utils.object2):
# Precondition: the substitution is semantically valid
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
# transaction will be reverted later.
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
if
prune
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
])
...
@@ -440,8 +441,6 @@ class FunctionGraph(utils.object2):
...
@@ -440,8 +441,6 @@ class FunctionGraph(utils.object2):
for
r
,
new_r
in
pairs
:
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
extend
(
self
,
feature
):
def
extend
(
self
,
feature
):
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature"
)
"renamed to FunctionGraph.attach_feature"
)
...
@@ -481,7 +480,9 @@ class FunctionGraph(utils.object2):
...
@@ -481,7 +480,9 @@ class FunctionGraph(utils.object2):
"""WRITEME
"""WRITEME
Removes the feature from the graph.
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined.
Calls feature.on_detach(function_graph) if an on_detach method
is defined.
"""
"""
try
:
try
:
self
.
_features
.
remove
(
feature
)
self
.
_features
.
remove
(
feature
)
...
@@ -491,9 +492,7 @@ class FunctionGraph(utils.object2):
...
@@ -491,9 +492,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
if
detach
is
not
None
:
detach
(
self
)
detach
(
self
)
### callback utils ###
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
"""WRITEME
Calls
Calls
...
@@ -518,7 +517,6 @@ class FunctionGraph(utils.object2):
...
@@ -518,7 +517,6 @@ class FunctionGraph(utils.object2):
else
:
else
:
raise
raise
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
"""WRITEME
Returns a dictionary d such that:
Returns a dictionary d such that:
...
@@ -534,9 +532,7 @@ class FunctionGraph(utils.object2):
...
@@ -534,9 +532,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
### misc ###
### misc ###
def
toposort
(
self
):
def
toposort
(
self
):
"""WRITEME
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
Returns an ordering of the graph's Apply nodes such that:
...
@@ -552,8 +548,8 @@ class FunctionGraph(utils.object2):
...
@@ -552,8 +548,8 @@ class FunctionGraph(utils.object2):
if
len
(
self
.
apply_nodes
)
<
2
:
if
len
(
self
.
apply_nodes
)
<
2
:
# optimization
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker
produces
# This special case happens a lot because the OpWiseCLinker
# 1-element graphs.
#
produces
1-element graphs.
return
list
(
self
.
apply_nodes
)
return
list
(
self
.
apply_nodes
)
fg
=
self
fg
=
self
...
@@ -568,8 +564,9 @@ class FunctionGraph(utils.object2):
...
@@ -568,8 +564,9 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
This is used primarily by the destroy_handler feature to ensure that
clients of any destroyed inputs have already computed their outputs.
all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
take care of computing dependencies by itself.
...
@@ -586,12 +583,13 @@ class FunctionGraph(utils.object2):
...
@@ -586,12 +583,13 @@ class FunctionGraph(utils.object2):
+
". Nondeterministic object is "
+
str
(
orderings
))
+
". Nondeterministic object is "
+
str
(
orderings
))
for
node
,
prereqs
in
orderings
.
items
():
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
" will be non-deterministic."
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
# eliminate duplicate prereqs
# eliminate duplicate prereqs
for
(
node
,
prereqs
)
in
ords
.
items
():
for
(
node
,
prereqs
)
in
ords
.
items
():
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
return
ords
return
ords
...
@@ -624,34 +622,48 @@ class FunctionGraph(utils.object2):
...
@@ -624,34 +622,48 @@ class FunctionGraph(utils.object2):
if
self
.
apply_nodes
!=
nodes
:
if
self
.
apply_nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
fgraph
is
not
self
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
for
i
,
variable
in
enumerate
(
node
.
inputs
):
for
i
,
variable
in
enumerate
(
node
.
inputs
):
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
if
(
node
,
i
)
not
in
variable
.
clients
:
if
(
node
,
i
)
not
in
variable
.
clients
:
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
if
set
(
self
.
variables
)
!=
variables
:
if
set
(
self
.
variables
)
!=
variables
:
missing
=
variables
.
difference
(
self
.
variables
)
missing
=
variables
.
difference
(
self
.
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
variable
in
variables
:
for
variable
in
variables
:
if
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
):
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
for
node
,
i
in
variable
.
clients
:
for
node
,
i
in
variable
.
clients
:
if
node
==
'output'
:
if
node
==
'output'
:
if
self
.
outputs
[
i
]
is
not
variable
:
if
self
.
outputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
continue
continue
if
node
not
in
nodes
:
if
node
not
in
nodes
:
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
if
node
.
inputs
[
i
]
is
not
variable
:
if
node
.
inputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
def
__str__
(
self
):
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
...
@@ -659,9 +671,7 @@ class FunctionGraph(utils.object2):
...
@@ -659,9 +671,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
### clone ###
### clone ###
def
clone
(
self
):
def
clone
(
self
):
"""WRITEME"""
"""WRITEME"""
return
self
.
clone_get_equiv
()[
0
]
return
self
.
clone_get_equiv
()[
0
]
...
...
theano/gof/toolbox.py
浏览文件 @
aca35acc
...
@@ -3,11 +3,9 @@ import time
...
@@ -3,11 +3,9 @@ import time
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
OrderedDict
from
theano.gof.python25
import
OrderedDict
from
theano.gof
import
graph
from
theano.gof
import
graph
class
AlreadyThere
(
Exception
):
class
AlreadyThere
(
Exception
):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
attempting to attach the feature already has a functionally identical
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论