提交 198cac60 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

documented _dfs_toposort

上级 c0c25559
...@@ -45,84 +45,108 @@ class DestroyHandler(object): ...@@ -45,84 +45,108 @@ class DestroyHandler(object):
def orderings(self, fgraph): def orderings(self, fgraph):
return self.map[fgraph].orderings(fgraph) return self.map[fgraph].orderings(fgraph)
def _dfs_toposort(inputs, outputs, orderings):
def _dfs_toposort(i, r_out, orderings):
""" """
i - list of inputs inputs - list of graph inputs
o - list of outputs must be Variable instances
orderings - dict of additions to the normal inputs and outputs collection must be tuple, list, or deque
outputs - list of graph outputs
must be Variable instances
collection must be tuple, list or deque
orderings - dictionary specifying extra dependencies besides
those encoded in Variable.owner / Apply.inputs
If orderings[my_apply] == dependencies,
then my_apply is an Apply instance,
dependencies is a set of Apply instances,
and every member of dependencies must be executed
before my_apply.
Returns nothing. Raises exception for graph with cycles The dependencies are typically used to prevent
inplace apply nodes from destroying their input before
other apply nodes with the same input access it.
Returns nothing.
Raises a ValueError whose message contains the substring 'cycle'
if the graph contains a cycle.
The purpose of this function is only to check for cycles.
""" """
#this is hard-coded reimplementation of functions from graph.py
# this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C. # reason: go faster, prepare for port to C.
# specifically, it is a drop-in replacement for graph.io_toposort
# this version is about 10% faster
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings # TODO: For more speed - use a defaultdict for the orderings
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(inputs)
iset = set(i)
if 0:
def expand(obj):
rval = []
if obj not in iset:
if isinstance(obj, graph.Variable):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, graph.Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
expand_cache = {} expand_cache = {}
# reachable, clients = stack_search( deque(r_out), deps, 'dfs', True) lifo_queue = deque(outputs)
start=deque(r_out) visited_set = set()
rval_set = set() visited_set.add(id(None))
rval_set.add(id(None))
rval_list = list() rval_list = list()
expand_inv = {} expand_inv = {}
sources = deque() fifo_queue = deque()
while start:
l = start.pop()# this makes the search dfs while lifo_queue:
if id(l) not in rval_set: # using pop rather than pop_left makes this queue LIFO
rval_list.append(l) # using a LIFO queue makes the search DFS
rval_set.add(id(l)) cur_var_or_node = lifo_queue.pop()
if l in iset:
assert not orderings.get(l, []) if id(cur_var_or_node) not in visited_set:
rval_list.append(cur_var_or_node)
visited_set.add(id(cur_var_or_node))
if cur_var_or_node in iset:
# Inputs to the graph must not have any dependencies
# Note: the empty list is treated as false
assert not orderings.get(cur_var_or_node, False)
expand_l = [] expand_l = []
else: else:
try: try:
if l.owner: if cur_var_or_node.owner:
expand_l = [l.owner] expand_l = [cur_var_or_node.owner]
else: else:
expand_l = [] expand_l = []
except AttributeError: except AttributeError:
expand_l = list(l.inputs) expand_l = list(cur_var_or_node.inputs)
expand_l.extend(orderings.get(l, [])) expand_l.extend(orderings.get(cur_var_or_node, []))
if expand_l: if expand_l:
for r in expand_l: for r in expand_l:
expand_inv.setdefault(r, []).append(l) # insert l in expand_inv[r]
start.extend(expand_l) # (if r is not already in expand_inv,
# intialize it to [])
expand_inv.setdefault(r, []).append(cur_var_or_node)
lifo_queue.extend(expand_l)
else: else:
sources.append(l) fifo_queue.append(cur_var_or_node)
expand_cache[l] = expand_l expand_cache[cur_var_or_node] = expand_l
assert len(rval_list) == len(rval_set)-1 # visited_set should be 1 longer because it contains id(None)
# TODO: why does it contain id(None) ?
assert len(rval_list) == len(visited_set)-1
rset = set() rset = set()
rlist = [] rlist = []
while sources: while fifo_queue:
node = sources.popleft() node = fifo_queue.popleft()
if node not in rset: if node not in rset:
rlist.append(node) rlist.append(node)
rset.add(node) rset.add(node)
for client in expand_inv.get(node, []): for client in expand_inv.get(node, []):
expand_cache[client] = [a for a in expand_cache[client] if a is not node] expand_cache[client] = [a for a in expand_cache[client] if a is not node]
if not expand_cache[client]: if not expand_cache[client]:
sources.append(client) fifo_queue.append(client)
if len(rlist) != len(rval_list): if len(rlist) != len(rval_list):
raise ValueError('graph contains cycles') raise ValueError('graph contains cycles')
...@@ -131,7 +155,6 @@ def _dfs_toposort(i, r_out, orderings): ...@@ -131,7 +155,6 @@ def _dfs_toposort(i, r_out, orderings):
def getroot(r, view_i): def getroot(r, view_i):
""" """
For views: Return non-view variable which is ultimatly viewed by r. For views: Return non-view variable which is ultimatly viewed by r.
...@@ -399,19 +422,17 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -399,19 +422,17 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
b) orderings cannot be topologically sorted. b) orderings cannot be topologically sorted.
""" """
#print '\nVALIDATE'
if self.destroyers: if self.destroyers:
try: try:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
except Exception, e: except Exception, e:
#print 'orderings failed with:', type(e), e.args
raise raise
#print 'orderings:', ords
try: try:
### graph.io_toposort(fgraph.inputs, fgraph.outputs, ords) ### graph.io_toposort(fgraph.inputs, fgraph.outputs, ords)
_dfs_toposort(fgraph.inputs, fgraph.outputs, ords) _dfs_toposort(fgraph.inputs, fgraph.outputs, ords)
except ValueError, e: except ValueError, e:
#print 'not passing.', ords
if 'cycles' in str(e): if 'cycles' in str(e):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError("Dependency graph contains cycles")
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论