提交 75e65054 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

doc + var renaming

上级 91c18b6d
...@@ -101,14 +101,14 @@ def _contains_cycle(inputs, outputs, orderings): ...@@ -101,14 +101,14 @@ def _contains_cycle(inputs, outputs, orderings):
# get_parents worked better. # get_parents worked better.
# IG: I tried tagging each variable and node with a visited flag # IG: I tried tagging each variable and node with a visited flag
# to avoid needing to do an expand_cache lookup to tell if a # to avoid needing to do a node_to_parents lookup to tell if a
# node was visited. This requires wrapping everything in a # node was visited. This requires wrapping everything in a
# try-finally and setting all the flags to false in the finally. # try-finally and setting all the flags to false in the finally.
# It resulted in a net slowdown, whether I used iteration # It resulted in a net slowdown, whether I used iteration
# on expand_cache or rval_list. (rval_list was a list # on node_to_parents or rval_list. (rval_list was a list
# whose contents were the same as expand_cache.keys()) # whose contents were the same as node_to_parents.keys())
# IG: I tried converting expand_cache to use an id for the key, # IG: I tried converting node_to_parents to use an id for the key,
# so that the dict would do reference counting on its keys. # so that the dict would do reference counting on its keys.
# For some reason this caused a slowdown--not sure if dict is # For some reason this caused a slowdown--not sure if dict is
# slow for int keys, or if call to id function is expensive. # slow for int keys, or if call to id function is expensive.
...@@ -116,43 +116,50 @@ def _contains_cycle(inputs, outputs, orderings): ...@@ -116,43 +116,50 @@ def _contains_cycle(inputs, outputs, orderings):
# DWF tried implementing this as cython, including the deque # DWF tried implementing this as cython, including the deque
# class when compiling cython, and only got a 10% speedup. # class when compiling cython, and only got a 10% speedup.
expand_cache = {}
# dict mapping an Apply or Variable instance to its parents
# (including parents imposed by orderings)
node_to_parents = {}
# the inverse mapping
node_to_children = {}
lifo_queue = deque(outputs) lifo_queue = deque(outputs)
#visited_set = set()
#visited_set.add(id(None))
#rval_list = list()
expand_inv = {}
fifo_queue = deque() fifo_queue = deque()
# Do a DFS through the graph, following the edges backwards from
# the outputs to the inputs. Build the node_to_parents and
# node_to_children dictionaries. Put the roots of the graph
# into fifo_queue
# TODO: does the order of the roots in the fifo_queue matter?
while lifo_queue: while lifo_queue:
# using pop rather than pop_left makes this queue LIFO # using pop rather than pop_left makes this queue LIFO
# using a LIFO queue makes the search DFS # using a LIFO queue makes the search DFS
cur_var_or_node = lifo_queue.pop() node = lifo_queue.pop()
if cur_var_or_node not in expand_cache: # id(cur_var_or_node) not in visited_set: if node not in node_to_parents:
#rval_list.append(cur_var_or_node)
#visited_set.add(id(cur_var_or_node))
if cur_var_or_node in iset: if node in iset:
# Inputs to the graph must not have any dependencies # Inputs to the graph must not have any dependencies
# Note: the empty list is treated as false # Note: the empty list is treated as false
assert not orderings.get(cur_var_or_node, False) assert not orderings.get(node, False)
expand_l = [] parents = []
else: else:
expand_l = cur_var_or_node.get_parents() parents = node.get_parents()
expand_l.extend(orderings.get(cur_var_or_node, [])) parents.extend(orderings.get(node, []))
if expand_l: if parents:
for r in expand_l: for r in parents:
# insert cur_var_or_node in expand_inv[r] # insert node in node_to_children[r]
# (if r is not already in expand_inv, # (if r is not already in node_to_children,
# intialize it to []) # intialize it to [])
expand_inv.setdefault(r, []).append(cur_var_or_node) node_to_children.setdefault(r, []).append(node)
lifo_queue.extend(expand_l) lifo_queue.extend(parents)
else: else:
fifo_queue.append(cur_var_or_node) fifo_queue.append(node)
expand_cache[cur_var_or_node] = expand_l node_to_parents[node] = parents
#assert len(rval_list) == len(expand_cache.keys())
rset = set() rset = set()
rlist = [] rlist = []
...@@ -161,12 +168,13 @@ def _contains_cycle(inputs, outputs, orderings): ...@@ -161,12 +168,13 @@ def _contains_cycle(inputs, outputs, orderings):
if node not in rset: if node not in rset:
rlist.append(node) rlist.append(node)
rset.add(node) rset.add(node)
for client in expand_inv.get(node, []): for client in node_to_children.get(node, []):
expand_cache[client] = [a for a in expand_cache[client] if a is not node] node_to_parents[client] = [a for a in node_to_parents[client] if a is not node]
if not expand_cache[client]: if not node_to_parents[client]:
fifo_queue.append(client) fifo_queue.append(client)
return len(rlist) != len(expand_cache.keys())
return len(rlist) != len(node_to_parents.keys())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论