提交 198cac60 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

documented _dfs_toposort

上级 c0c25559
......@@ -45,84 +45,108 @@ class DestroyHandler(object):
def orderings(self, fgraph):
return self.map[fgraph].orderings(fgraph)
def _dfs_toposort(i, r_out, orderings):
def _dfs_toposort(inputs, outputs, orderings):
"""
i - list of inputs
o - list of outputs
orderings - dict of additions to the normal inputs and outputs
inputs - list of graph inputs
must be Variable instances
collection must be tuple, list, or deque
outputs - list of graph outputs
must be Variable instances
collection must be tuple, list or deque
orderings - dictionary specifying extra dependencies besides
those encoded in Variable.owner / Apply.inputs
If orderings[my_apply] == dependencies,
then my_apply is an Apply instance,
dependencies is a set of Apply instances,
and every member of dependencies must be executed
before my_apply.
Returns nothing. Raises exception for graph with cycles
The dependencies are typically used to prevent
inplace apply nodes from destroying their input before
other apply nodes with the same input access it.
Returns nothing.
Raises a ValueError whose message contains the substring 'cycle'
if the graph contains a cycle.
The purpose of this function is only to check for cycles.
"""
#this is hard-coded reimplementation of functions from graph.py
# this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
# specifically, it is a drop-in replacement for graph.io_toposort
# this version is about 10% faster
assert isinstance(r_out, (tuple, list, deque))
assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(i)
if 0:
def expand(obj):
rval = []
if obj not in iset:
if isinstance(obj, graph.Variable):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, graph.Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
iset = set(inputs)
expand_cache = {}
# reachable, clients = stack_search( deque(r_out), deps, 'dfs', True)
start=deque(r_out)
rval_set = set()
rval_set.add(id(None))
lifo_queue = deque(outputs)
visited_set = set()
visited_set.add(id(None))
rval_list = list()
expand_inv = {}
sources = deque()
while start:
l = start.pop()# this makes the search dfs
if id(l) not in rval_set:
rval_list.append(l)
rval_set.add(id(l))
if l in iset:
assert not orderings.get(l, [])
fifo_queue = deque()
while lifo_queue:
# using pop rather than pop_left makes this queue LIFO
# using a LIFO queue makes the search DFS
cur_var_or_node = lifo_queue.pop()
if id(cur_var_or_node) not in visited_set:
rval_list.append(cur_var_or_node)
visited_set.add(id(cur_var_or_node))
if cur_var_or_node in iset:
# Inputs to the graph must not have any dependencies
# Note: the empty list is treated as false
assert not orderings.get(cur_var_or_node, False)
expand_l = []
else:
try:
if l.owner:
expand_l = [l.owner]
if cur_var_or_node.owner:
expand_l = [cur_var_or_node.owner]
else:
expand_l = []
except AttributeError:
expand_l = list(l.inputs)
expand_l.extend(orderings.get(l, []))
expand_l = list(cur_var_or_node.inputs)
expand_l.extend(orderings.get(cur_var_or_node, []))
if expand_l:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
# insert l in expand_inv[r]
# (if r is not already in expand_inv,
# intialize it to [])
expand_inv.setdefault(r, []).append(cur_var_or_node)
lifo_queue.extend(expand_l)
else:
sources.append(l)
expand_cache[l] = expand_l
assert len(rval_list) == len(rval_set)-1
fifo_queue.append(cur_var_or_node)
expand_cache[cur_var_or_node] = expand_l
# visited_set should be 1 longer because it contains id(None)
# TODO: why does it contain id(None) ?
assert len(rval_list) == len(visited_set)-1
rset = set()
rlist = []
while sources:
node = sources.popleft()
while fifo_queue:
node = fifo_queue.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
for client in expand_inv.get(node, []):
expand_cache[client] = [a for a in expand_cache[client] if a is not node]
if not expand_cache[client]:
sources.append(client)
fifo_queue.append(client)
if len(rlist) != len(rval_list):
raise ValueError('graph contains cycles')
......@@ -131,7 +155,6 @@ def _dfs_toposort(i, r_out, orderings):
def getroot(r, view_i):
"""
For views: Return non-view variable which is ultimatly viewed by r.
......@@ -399,19 +422,17 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
b) orderings cannot be topologically sorted.
"""
#print '\nVALIDATE'
if self.destroyers:
try:
ords = self.orderings(fgraph)
except Exception, e:
#print 'orderings failed with:', type(e), e.args
raise
#print 'orderings:', ords
try:
### graph.io_toposort(fgraph.inputs, fgraph.outputs, ords)
_dfs_toposort(fgraph.inputs, fgraph.outputs, ords)
except ValueError, e:
#print 'not passing.', ords
if 'cycles' in str(e):
raise InconsistencyError("Dependency graph contains cycles")
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论