提交 848988cd authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add partially ordered sort

上级 f7b869ff
......@@ -1045,6 +1045,79 @@ def dependence(a, b):
if depends((b, a)): return -1
return 0
def reverse_dict(d):
""" Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d)
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
"""
result = {}
for key in d:
for val in d[key]:
result[val] = result.get(val, tuple()) + (key, )
return result
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = set((v for v in edges if v not in incoming_edges))
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def posort(l, *cmps):
""" Partially ordered sort with multiple comparators
implemented with _toposort """
comes_before = {a: set() for a in l}
comes_after = {a: set() for a in l}
def add_links(a, b): # b depends on a
comes_after[a].add(b)
comes_after[a].update(comes_after[b])
for c in comes_before[a]:
comes_after[c].update(comes_after[a])
comes_before[b].add(a)
comes_before[b].update(comes_before[a])
for c in comes_after[b]:
comes_before[c].update(comes_before[b])
def check():
""" Tests for cycles in manufactured edges """
for a in l:
for b in l:
assert not(b in comes_after[a] and a in comes_after[b])
for cmp in cmps:
for a in l:
for b in l:
if cmp(a, b) < 0: # a wants to come before b
# if this wouldn't cause a cycle and isn't already known
if not b in comes_before[a] and not b in comes_after[a]:
add_links(a, b)
# check() # debug code
return _toposort(comes_after)
def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators
......@@ -1065,10 +1138,4 @@ def sort_apply_nodes(inputs, outputs, cmps):
dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
"""
# An aggregate comparator - looks at each cmp function in order
def cmp(a,b, fns=cmps):
if not fns: return 0
head, tail = fns[0], fns[1:]
return head(a, b) or cmp(a, b, tail)
return sorted(list_of_nodes(inputs, outputs), cmp=cmp)
return posort(list_of_nodes(inputs, outputs), *cmps)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论