Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
67927672
提交
67927672
authored
6月 01, 2017
作者:
Frédéric Bastien
提交者:
GitHub
6月 01, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5794 from ReyhaneAskari/faster_topo
Faster topo
上级
ba982dbf
82abb2a9
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
141 行增加
和
396 行删除
+141
-396
config.txt
doc/library/config.txt
+14
-0
profiling.py
theano/compile/profiling.py
+0
-1
configdefaults.py
theano/configdefaults.py
+11
-0
destroyhandler.py
theano/gof/destroyhandler.py
+98
-385
fg.py
theano/gof/fg.py
+18
-10
没有找到文件。
doc/library/config.txt
浏览文件 @
67927672
...
@@ -225,6 +225,20 @@ import theano and print the config variable, as in:
...
@@ -225,6 +225,20 @@ import theano and print the config variable, as in:
If True, we will print extra scan debug information.
If True, we will print extra scan debug information.
.. attribute:: cycle_detection
String value, either ``regular`` or ``fast```
Default: ``regular``
If :attr:`cycle_detection` is set to ``regular``, most inplaces are allowed,
but it is slower. If :attr:`cycle_detection` is set to ``faster``,
less inplaces are allowed, but it makes the compilation faster.
The interaction of which one give the lower peak memory usage is complicated and
not predictable, so if you are close to the peak memory usage, triyng both
could give you a small gain.
.. attribute:: openmp
.. attribute:: openmp
Bool value: either ``True`` or ``False``
Bool value: either ``True`` or ``False``
...
...
theano/compile/profiling.py
浏览文件 @
67927672
...
@@ -830,7 +830,6 @@ class ProfileStats(object):
...
@@ -830,7 +830,6 @@ class ProfileStats(object):
"""
"""
from
theano.gpuarray
import
GpuArrayType
from
theano.gpuarray
import
GpuArrayType
# Initial Mem info values [CPU, GPU]
# Initial Mem info values [CPU, GPU]
node_memory_size
=
[
0
,
0
]
node_memory_size
=
[
0
,
0
]
running_memory_size
=
[
0
,
0
]
running_memory_size
=
[
0
,
0
]
...
...
theano/configdefaults.py
浏览文件 @
67927672
...
@@ -1476,6 +1476,17 @@ AddConfigVar('compile.wait',
...
@@ -1476,6 +1476,17 @@ AddConfigVar('compile.wait',
IntParam
(
5
,
lambda
i
:
i
>
0
,
allow_override
=
False
),
IntParam
(
5
,
lambda
i
:
i
>
0
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'cycle_detection'
,
"If cycle_detection is set to regular, most inplaces are allowed,"
"but it is slower. If cycle_detection is set to faster, less inplaces"
"are allowed, but it makes the compilation faster."
"The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak"
"memory usage, triyng both could give you a small gain. "
,
EnumStr
(
'regular'
,
'fast'
),
in_c_key
=
False
)
def
_timeout_default
():
def
_timeout_default
():
return
theano
.
config
.
compile
.
wait
*
24
return
theano
.
config
.
compile
.
wait
*
24
...
...
theano/gof/destroyhandler.py
浏览文件 @
67927672
...
@@ -8,8 +8,10 @@ from __future__ import absolute_import, print_function, division
...
@@ -8,8 +8,10 @@ from __future__ import absolute_import, print_function, division
from
collections
import
deque
,
OrderedDict
from
collections
import
deque
,
OrderedDict
from
six
import
iteritems
from
six
import
iteritems
import
itertools
import
theano
import
theano
from
theano
import
config
from
.
import
toolbox
from
.
import
toolbox
from
.
import
graph
from
.
import
graph
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
...
@@ -252,370 +254,6 @@ def fast_inplace_check(inputs):
...
@@ -252,370 +254,6 @@ def fast_inplace_check(inputs):
i
not
in
protected_inputs
]
i
not
in
protected_inputs
]
return
inputs
return
inputs
if
0
:
# old, non-incremental version of the DestroyHandler
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
"""
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
Several data structures are used to do this.
When an Op uses its view_map property to declare that an output may be
aliased to an input, then if that output is destroyed, the input is also
considering to be destroyed. The view_maps of several Ops can feed into
one another and form a directed graph. The consequence of destroying any
variable in such a graph is that all variables in the graph must be
considered to be destroyed, because they could all be refering to the
same underlying storage. In the current implementation, that graph is a
tree, and the root of that tree is called the foundation. The `droot`
property of this class maps from every graph variable to its foundation.
The `impact` property maps backward from the foundation to all of the
variables that depend on it. When any variable is destroyed, this class
marks the foundation of that variable as being destroyed, with the
`root_destroyer` property.
"""
droot
=
{}
"""
destroyed view + nonview variables -> foundation.
"""
impact
=
{}
"""
destroyed nonview variable -> it + all views of it.
"""
root_destroyer
=
{}
"""
root -> destroyer apply.
"""
def
__init__
(
self
,
do_imports_on_attach
=
True
):
self
.
fgraph
=
None
self
.
do_imports_on_attach
=
do_imports_on_attach
def
on_attach
(
self
,
fgraph
):
"""
When attaching to a new fgraph, check that
1) This DestroyHandler wasn't already attached to some fgraph
(its data structures are only set up to serve one)
2) The FunctionGraph doesn't already have a DestroyHandler.
This would result in it validating everything twice, causing
compilation to be slower.
TODO: WRITEME: what does this do besides the checks?
"""
# Do the checking #
already_there
=
False
if
self
.
fgraph
not
in
[
None
,
fgraph
]:
raise
Exception
(
"A DestroyHandler instance can only serve"
" one FunctionGraph. (Matthew 6:24)"
)
for
attr
in
(
'destroyers'
,
'destroy_handler'
):
if
hasattr
(
fgraph
,
attr
):
already_there
=
True
if
already_there
:
# FunctionGraph.attach_feature catches AlreadyThere
# and cancels the attachment
raise
toolbox
.
AlreadyThere
(
"DestroyHandler feature is already present or in"
" conflict with another plugin."
)
# end of checking #
def
get_destroyers_of
(
r
):
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
try
:
return
[
root_destroyer
[
droot
[
r
]]]
except
Exception
:
return
[]
fgraph
.
destroyers
=
get_destroyers_of
fgraph
.
destroy_handler
=
self
self
.
fgraph
=
fgraph
self
.
destroyers
=
OrderedSet
()
# set of Apply instances with non-null destroy_map
self
.
view_i
=
{}
# variable -> variable used in calculation
self
.
view_o
=
{}
# variable -> set of variables that use this one as a direct input
# clients: how many times does an apply use a given variable
self
.
clients
=
{}
# variable -> apply -> ninputs
self
.
stale_droot
=
True
# IG: It's unclear if this is meant to be included in deployed code. It looks like
# it is unnecessary if FunctionGraph is working correctly, so I am commenting uses
# of it (for speed) but leaving the commented code in place so it is easy to restore
# for debugging purposes.
# Note: is there anything like the C preprocessor for python? It would be useful to
# just ifdef these things out
# self.debug_all_apps = set()
if
self
.
do_imports_on_attach
:
toolbox
.
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
refresh_droot_impact
(
self
):
if
self
.
stale_droot
:
self
.
droot
,
self
.
impact
,
self
.
root_destroyer
=
_build_droot_impact
(
self
)
self
.
stale_droot
=
False
return
self
.
droot
,
self
.
impact
,
self
.
root_destroyer
def
on_detach
(
self
,
fgraph
):
if
fgraph
is
not
self
.
fgraph
:
raise
Exception
(
"detaching wrong fgraph"
,
fgraph
)
del
self
.
destroyers
del
self
.
view_i
del
self
.
view_o
del
self
.
clients
del
self
.
stale_droot
assert
self
.
fgraph
.
destroyer_handler
is
self
delattr
(
self
.
fgraph
,
'destroyers'
)
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""
Add Apply instance to set which must be computed.
"""
# if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if
getattr
(
app
.
op
,
'destroy_map'
,
{}):
self
.
destroyers
.
add
(
app
)
# add this symbol to the forward and backward maps
for
o_idx
,
i_idx_list
in
iteritems
(
getattr
(
app
.
op
,
'view_map'
,
{})):
if
len
(
i_idx_list
)
>
1
:
raise
NotImplementedError
(
'destroying this output invalidates multiple inputs'
,
(
app
.
op
))
o
=
app
.
outputs
[
o_idx
]
i
=
app
.
inputs
[
i_idx_list
[
0
]]
self
.
view_i
[
o
]
=
i
self
.
view_o
.
setdefault
(
i
,
OrderedSet
())
.
add
(
o
)
# update self.clients
for
i
,
input
in
enumerate
(
app
.
inputs
):
self
.
clients
.
setdefault
(
input
,
{})
.
setdefault
(
app
,
0
)
self
.
clients
[
input
][
app
]
+=
1
for
i
,
output
in
enumerate
(
app
.
outputs
):
self
.
clients
.
setdefault
(
output
,
{})
self
.
stale_droot
=
True
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""
Remove Apply instance from set which must be computed.
"""
# if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# UPDATE self.clients
for
i
,
input
in
enumerate
(
OrderedSet
(
app
.
inputs
)):
del
self
.
clients
[
input
][
app
]
if
getattr
(
app
.
op
,
'destroy_map'
,
{}):
self
.
destroyers
.
remove
(
app
)
# Note: leaving empty client dictionaries in the struct.
# Why? It's a pain to remove them. I think they aren't doing any harm, they will be
# deleted on_detach().
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
iteritems
(
getattr
(
app
.
op
,
'view_map'
,
{})):
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
o
=
app
.
outputs
[
o_idx
]
i
=
app
.
inputs
[
i_idx_list
[
0
]]
del
self
.
view_i
[
o
]
self
.
view_o
[
i
]
.
remove
(
o
)
if
not
self
.
view_o
[
i
]:
del
self
.
view_o
[
i
]
self
.
stale_droot
=
True
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
,
reason
):
"""
app.inputs[i] changed from old_r to new_r.
"""
if
app
==
'output'
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
pass
else
:
# if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
self
.
clients
[
old_r
][
app
]
-=
1
if
self
.
clients
[
old_r
][
app
]
==
0
:
del
self
.
clients
[
old_r
][
app
]
self
.
clients
.
setdefault
(
new_r
,
{})
.
setdefault
(
app
,
0
)
self
.
clients
[
new_r
][
app
]
+=
1
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
iteritems
(
getattr
(
app
.
op
,
'view_map'
,
{})):
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
i_idx
=
i_idx_list
[
0
]
output
=
app
.
outputs
[
o_idx
]
if
i_idx
==
i
:
if
app
.
inputs
[
i_idx
]
is
not
new_r
:
raise
ProtocolError
(
"wrong new_r on change"
)
self
.
view_i
[
output
]
=
new_r
self
.
view_o
[
old_r
]
.
remove
(
output
)
if
not
self
.
view_o
[
old_r
]:
del
self
.
view_o
[
old_r
]
self
.
view_o
.
setdefault
(
new_r
,
OrderedSet
())
.
add
(
output
)
self
.
stale_droot
=
True
def
validate
(
self
,
fgraph
):
"""
Return None.
Raise InconsistencyError when
a) orderings() raises an error
b) orderings cannot be topologically sorted.
"""
if
self
.
destroyers
:
ords
=
self
.
orderings
(
fgraph
)
if
_contains_cycle
(
fgraph
,
ords
):
raise
InconsistencyError
(
"Dependency graph contains cycles"
)
else
:
# James's Conjecture:
# If there are no destructive ops, then there can be no cycles.
pass
return
True
def
orderings
(
self
,
fgraph
):
"""
Return orderings induced by destructive operations.
Raise InconsistencyError when
a) attempting to destroy indestructable variable, or
b) attempting to destroy a value multiple times, or
c) an Apply destroys (illegally) one of its own inputs by aliasing
"""
rval
=
OrderedDict
()
if
self
.
destroyers
:
# BUILD DATA STRUCTURES
# CHECK for multiple destructions during construction of variables
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
# check for destruction of constants
illegal_destroy
=
[
r
for
r
in
droot
if
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
r
,
graph
.
Constant
)]
if
illegal_destroy
:
# print 'destroying illegally'
raise
InconsistencyError
(
"Attempting to destroy indestructible variables:
%
s"
%
illegal_destroy
)
# add destroyed variable clients as computational dependencies
for
app
in
self
.
destroyers
:
# for each destroyed input...
for
output_idx
,
input_idx_list
in
iteritems
(
app
.
op
.
destroy_map
):
destroyed_idx
=
input_idx_list
[
0
]
destroyed_variable
=
app
.
inputs
[
destroyed_idx
]
root
=
droot
[
destroyed_variable
]
root_impact
=
impact
[
root
]
# we generally want to put all clients of things which depend on root
# as pre-requisites of app.
# But, app is itself one such client!
# App will always be a client of the node we're destroying
# (destroyed_variable, but the tricky thing is when it is also a client of
# *another variable* viewing on the root. Generally this is illegal, (e.g.,
# add_inplace(x, x.T). In some special cases though, the in-place op will
# actually be able to work properly with multiple destroyed inputs (e.g,
# add_inplace(x, x). An Op that can still work in this case should declare
# so via the 'destroyhandler_tolerate_same' attribute or
# 'destroyhandler_tolerate_aliased' attribute.
#
# destroyhandler_tolerate_same should be a list of pairs of the form
# [(idx0, idx1), (idx0, idx2), ...]
# The first element of each pair is the input index of a destroyed
# variable.
# The second element of each pair is the index of a different input where
# we will permit exactly the same variable to appear.
# For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
# input is also allowed to appear as the second argument.
#
# destroyhandler_tolerate_aliased is the same sort of list of
# pairs.
# op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
# destroyhandler to IGNORE an aliasing between a destroyed
# input idx0 and another input idx1.
# This is generally a bad idea, but it is safe in some
# cases, such as
# - the op reads from the aliased idx1 before modifying idx0
# - the idx0 and idx1 are guaranteed not to overlap (e.g.
# they are pointed at different rows of a matrix).
#
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
# print 'tolerated', tolerated
# print 'ignored', ignored
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
continue
if
input
in
root_impact
\
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
%
(
app
,
destroyed_idx
,
i
))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
root_clients
=
OrderedSet
()
for
r
in
root_impact
:
assert
not
[
a
for
a
,
c
in
iteritems
(
self
.
clients
[
r
])
if
not
c
]
root_clients
.
update
([
a
for
a
,
c
in
iteritems
(
self
.
clients
[
r
])
if
c
])
root_clients
.
remove
(
app
)
if
root_clients
:
rval
[
app
]
=
root_clients
return
rval
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
# noqa
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
# noqa
"""
"""
...
@@ -661,7 +299,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -661,7 +299,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
"""
pickle_rm_attr
=
[
"destroyers"
]
pickle_rm_attr
=
[
"destroyers"
]
def
__init__
(
self
,
do_imports_on_attach
=
True
):
def
__init__
(
self
,
do_imports_on_attach
=
True
,
algo
=
None
):
self
.
fgraph
=
None
self
.
fgraph
=
None
self
.
do_imports_on_attach
=
do_imports_on_attach
self
.
do_imports_on_attach
=
do_imports_on_attach
...
@@ -691,6 +329,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -691,6 +329,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
"""
self
.
root_destroyer
=
OrderedDict
()
self
.
root_destroyer
=
OrderedDict
()
if
algo
is
None
:
algo
=
config
.
cycle_detection
self
.
algo
=
algo
self
.
fail_validate
=
OrderedDict
()
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
"""
"""
...
@@ -733,19 +375,19 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -733,19 +375,19 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self
.
fgraph
=
fgraph
self
.
fgraph
=
fgraph
self
.
destroyers
=
OrderedSet
()
# set of Apply instances with non-null destroy_map
self
.
destroyers
=
OrderedSet
()
# set of Apply instances with non-null destroy_map
self
.
view_i
=
OrderedDict
()
# variable -> variable used in calculation
self
.
view_i
=
{}
# variable -> variable used in calculation
self
.
view_o
=
OrderedDict
()
# variable -> set of variables that use this one as a direct input
self
.
view_o
=
{}
# variable -> set of variables that use this one as a direct input
# clients: how many times does an apply use a given variable
# clients: how many times does an apply use a given variable
self
.
clients
=
OrderedDict
()
# variable -> apply -> ninputs
self
.
clients
=
OrderedDict
()
# variable -> apply -> ninputs
self
.
stale_droot
=
True
self
.
stale_droot
=
True
self
.
debug_all_apps
=
OrderedS
et
()
self
.
debug_all_apps
=
s
et
()
if
self
.
do_imports_on_attach
:
if
self
.
do_imports_on_attach
:
toolbox
.
Bookkeeper
.
on_attach
(
self
,
fgraph
)
toolbox
.
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
unpickle
(
self
,
fgraph
):
def
unpickle
(
self
,
fgraph
):
def
get_destroyers_of
(
r
):
def
get_destroyers_of
(
r
):
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
droot
,
_
,
root_destroyer
=
self
.
refresh_droot_impact
()
try
:
try
:
return
[
root_destroyer
[
droot
[
r
]]]
return
[
root_destroyer
[
droot
[
r
]]]
except
Exception
:
except
Exception
:
...
@@ -777,20 +419,65 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -777,20 +419,65 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
delattr
(
self
.
fgraph
,
'destroy_handler'
)
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
self
.
fgraph
=
None
def
fast_destroy
(
self
,
app
,
reason
):
"""
Do the check for only 1 level.
For now:
- Destroyed variables can have only 1 clients.
- Allow view to have multiple clients.
- Allow sequence of view.
- But don't allow to destroy view
"""
dm
=
getattr
(
app
.
op
,
'destroy_map'
,
None
)
if
not
dm
:
return
inputs
=
set
(
itertools
.
chain
.
from_iterable
(
dm
.
values
()))
# list of app's destroyed inputs
for
inp_idx
in
inputs
:
inp
=
app
.
inputs
[
inp_idx
]
if
getattr
(
inp
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
inp
,
graph
.
Constant
):
self
.
fail_validate
[
app
]
=
InconsistencyError
(
"Attempting to destroy indestructible variables:
%
s"
%
inp
)
elif
len
(
inp
.
clients
)
>
1
:
self
.
fail_validate
[
app
]
=
theano
.
gof
.
InconsistencyError
(
"Destroyed variable has more than one client. "
+
str
(
reason
))
elif
inp
.
owner
:
app2
=
inp
.
owner
inp_idx2
=
app2
.
outputs
.
index
(
inp
)
v
=
getattr
(
app2
.
op
,
'view_map'
,
{})
d
=
getattr
(
app2
.
op
,
'destroy_map'
,
{})
if
v
:
v
=
v
.
get
(
inp_idx2
,
[])
if
len
(
v
)
>
0
:
self
.
fail_validate
[
app
]
=
theano
.
gof
.
InconsistencyError
(
"Destroyed variable has view_map. "
+
str
(
reason
))
elif
d
:
d
=
d
.
get
(
inp_idx2
,
[])
if
len
(
d
)
>
0
:
self
.
fail_validate
[
app
]
=
theano
.
gof
.
InconsistencyError
(
"Destroyed variable has destroy_map. "
+
str
(
reason
))
# These 2 assertions are commented since this function is called so many times
# but they should be true.
# assert len(v) <= 1
# assert len(d) <= 1
def
on_import
(
self
,
fgraph
,
app
,
reason
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""
"""
Add Apply instance to set which must be computed.
Add Apply instance to set which must be computed.
"""
"""
if
app
in
self
.
debug_all_apps
:
if
app
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"double import"
)
raise
ProtocolError
(
"double import"
)
self
.
debug_all_apps
.
add
(
app
)
self
.
debug_all_apps
.
add
(
app
)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
# If it's a destructive op, add it to our watch list
if
getattr
(
app
.
op
,
'destroy_map'
,
{}
):
if
getattr
(
app
.
op
,
'destroy_map'
,
None
):
self
.
destroyers
.
add
(
app
)
self
.
destroyers
.
add
(
app
)
if
self
.
algo
==
'fast'
:
self
.
fast_destroy
(
app
,
reason
)
# add this symbol to the forward and backward maps
# add this symbol to the forward and backward maps
for
o_idx
,
i_idx_list
in
iteritems
(
getattr
(
app
.
op
,
'view_map'
,
{})):
for
o_idx
,
i_idx_list
in
iteritems
(
getattr
(
app
.
op
,
'view_map'
,
{})):
...
@@ -823,7 +510,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -823,7 +510,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self
.
debug_all_apps
.
remove
(
app
)
self
.
debug_all_apps
.
remove
(
app
)
# UPDATE self.clients
# UPDATE self.clients
for
i
,
input
in
enumerate
(
OrderedSet
(
app
.
inputs
)
):
for
i
nput
in
set
(
app
.
inputs
):
del
self
.
clients
[
input
][
app
]
del
self
.
clients
[
input
][
app
]
if
getattr
(
app
.
op
,
'destroy_map'
,
OrderedDict
()):
if
getattr
(
app
.
op
,
'destroy_map'
,
OrderedDict
()):
...
@@ -849,6 +536,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -849,6 +536,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
del
self
.
view_o
[
i
]
del
self
.
view_o
[
i
]
self
.
stale_droot
=
True
self
.
stale_droot
=
True
if
app
in
self
.
fail_validate
:
del
self
.
fail_validate
[
app
]
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
,
reason
):
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
,
reason
):
"""
"""
...
@@ -891,6 +580,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -891,6 +580,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self
.
view_o
.
setdefault
(
new_r
,
OrderedSet
())
.
add
(
output
)
self
.
view_o
.
setdefault
(
new_r
,
OrderedSet
())
.
add
(
output
)
if
self
.
algo
==
'fast'
:
if
app
in
self
.
fail_validate
:
del
self
.
fail_validate
[
app
]
self
.
fast_destroy
(
app
,
reason
)
self
.
stale_droot
=
True
self
.
stale_droot
=
True
def
validate
(
self
,
fgraph
):
def
validate
(
self
,
fgraph
):
...
@@ -903,10 +596,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -903,10 +596,27 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
"""
if
self
.
destroyers
:
if
self
.
destroyers
:
ords
=
self
.
orderings
(
fgraph
)
if
self
.
algo
==
'fast'
:
if
self
.
fail_validate
:
if
_contains_cycle
(
fgraph
,
ords
):
app_err_pairs
=
self
.
fail_validate
raise
InconsistencyError
(
"Dependency graph contains cycles"
)
self
.
fail_validate
=
OrderedDict
()
# self.fail_validate can only be a hint that maybe/probably
# there is a cycle.This is because inside replace() we could
# record many reasons to not accept a change, but we don't
# know which one will fail first inside validate(). Thus,the
# graph might have already changed when we raise the
# self.fail_validate error. So before raising the error, we
# double check here.
for
app
in
app_err_pairs
:
if
app
in
fgraph
.
apply_nodes
:
self
.
fast_destroy
(
app
,
'validate'
)
if
self
.
fail_validate
:
self
.
fail_validate
=
app_err_pairs
raise
app_err_pairs
[
app
]
else
:
ords
=
self
.
orderings
(
fgraph
,
ordered
=
False
)
if
_contains_cycle
(
fgraph
,
ords
):
raise
InconsistencyError
(
"Dependency graph contains cycles"
)
else
:
else
:
# James's Conjecture:
# James's Conjecture:
# If there are no destructive ops, then there can be no cycles.
# If there are no destructive ops, then there can be no cycles.
...
@@ -921,7 +631,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -921,7 +631,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
pass
pass
return
True
return
True
def
orderings
(
self
,
fgraph
):
def
orderings
(
self
,
fgraph
,
ordered
=
True
):
"""
"""
Return orderings induced by destructive operations.
Return orderings induced by destructive operations.
...
@@ -931,7 +641,12 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -931,7 +641,12 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
c) an Apply destroys (illegally) one of its own inputs by aliasing
c) an Apply destroys (illegally) one of its own inputs by aliasing
"""
"""
rval
=
OrderedDict
()
if
ordered
:
set_type
=
OrderedSet
rval
=
OrderedDict
()
else
:
set_type
=
set
rval
=
dict
()
if
self
.
destroyers
:
if
self
.
destroyers
:
# BUILD DATA STRUCTURES
# BUILD DATA STRUCTURES
...
@@ -951,7 +666,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -951,7 +666,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# add destroyed variable clients as computational dependencies
# add destroyed variable clients as computational dependencies
for
app
in
self
.
destroyers
:
for
app
in
self
.
destroyers
:
# keep track of clients that should run before the current Apply
# keep track of clients that should run before the current Apply
root_clients
=
OrderedSet
()
root_clients
=
set_type
()
# for each destroyed input...
# for each destroyed input...
for
output_idx
,
input_idx_list
in
iteritems
(
app
.
op
.
destroy_map
):
for
output_idx
,
input_idx_list
in
iteritems
(
app
.
op
.
destroy_map
):
destroyed_idx
=
input_idx_list
[
0
]
destroyed_idx
=
input_idx_list
[
0
]
...
@@ -996,16 +711,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -996,16 +711,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
tolerate_same
=
getattr
(
app
.
op
,
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
'destroyhandler_tolerate_same'
,
[])
assert
isinstance
(
tolerate_same
,
list
)
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
OrderedS
et
(
idx1
for
idx0
,
idx1
in
tolerate_same
tolerated
=
s
et
(
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
assert
isinstance
(
tolerate_aliased
,
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
ignored
=
set
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
# print 'tolerated', tolerated
# print 'ignored', ignored
for
i
,
input
in
enumerate
(
app
.
inputs
):
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
if
i
in
ignored
:
continue
continue
...
...
theano/gof/fg.py
浏览文件 @
67927672
...
@@ -654,8 +654,9 @@ class FunctionGraph(utils.object2):
...
@@ -654,8 +654,9 @@ class FunctionGraph(utils.object2):
take care of computing dependencies by itself.
take care of computing dependencies by itself.
"""
"""
ords
=
OrderedDict
()
assert
isinstance
(
self
.
_features
,
list
)
assert
isinstance
(
self
.
_features
,
list
)
all_orderings
=
[]
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
if
hasattr
(
feature
,
'orderings'
):
orderings
=
feature
.
orderings
(
self
)
orderings
=
feature
.
orderings
(
self
)
...
@@ -664,17 +665,24 @@ class FunctionGraph(utils.object2):
...
@@ -664,17 +665,24 @@ class FunctionGraph(utils.object2):
str
(
feature
.
orderings
)
+
str
(
feature
.
orderings
)
+
". Nondeterministic object is "
+
". Nondeterministic object is "
+
str
(
orderings
))
str
(
orderings
))
if
len
(
orderings
)
>
0
:
all_orderings
.
append
(
orderings
)
for
node
,
prereqs
in
iteritems
(
orderings
):
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
if
len
(
all_orderings
)
==
1
:
# If there is only 1 ordering, we reuse it directly.
return
all_orderings
[
0
]
.
copy
()
else
:
# If there is more than 1 ordering, combine them.
ords
=
OrderedDict
()
for
orderings
in
all_orderings
:
for
node
,
prereqs
in
iteritems
(
orderings
):
for
node
,
prereqs
in
iteritems
(
orderings
):
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
# eliminate duplicate prereqs
return
ords
for
(
node
,
prereqs
)
in
iteritems
(
ords
):
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
return
ords
def
check_integrity
(
self
):
def
check_integrity
(
self
):
"""
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论