提交 c7351f2d authored 作者: James Bergstra's avatar James Bergstra

consistent toposort

上级 73b15ce2
from collections import deque
import unittest
from graph import *
from op import Op
from result import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1:
testcase = unittest.TestCase
else:
......@@ -146,9 +164,87 @@ class _test_clone(testcase):
self.failUnless( s == ["MyOp(R7, R8)"], s)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def prenode(obj):
if isinstance(obj, Result):
if obj.owner:
return [obj.owner]
if isinstance(obj, Op):
return obj.inputs
class _test_toposort(testcase):
def test0(self):
"""Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r2)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all)
def test1(self):
"""Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
def test2(self):
"""Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
r2b = o.outputs[0]
o2 = MyOp(r2b, r2b)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
o2 = MyOp(r2b, r5)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
def test3(self):
"""Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all)
def test4(self):
"""Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all)
def test5(self):
"""Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r4)
all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all)
if __name__ == '__main__':
unittest.main()
#_test_inputs('test_unreached_inputs').debug()
if 1:
#run all tests
unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
......@@ -7,10 +7,10 @@ import result, op
__all__ = ['inputs',
'results_and_orphans', 'results', 'orphans',
'results_and_orphans', 'results', 'orphans', 'stack_search',
'ops',
'clone', 'clone_get_equiv',
'io_toposort',
'io_toposort', 'general_toposort',
'default_leaf_formatter', 'default_node_formatter',
'op_as_string',
'as_string',
......@@ -61,6 +61,8 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
return rval_list, expand_inv
return rval_list
@utils.deprecated('gof.graph', 'is this function ever used?')
def inputs(result_list):
"""
@type result_list: list of L{Result}
......@@ -80,54 +82,6 @@ def inputs(result_list):
return rval
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def results_and_orphans(r_in, r_out, except_unreachable_input=False):
r_in_set = set(r_in)
class Dummy(object): pass
......@@ -282,31 +236,70 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
return d
def general_toposort(r_out, deps):
"""
@note: deps(i) should behave like a pure function (no funny business with
internal state)
def io_toposort(i, o, orderings = {}):
@note: deps(i) can/should be cached by the deps function to be fast
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
deps_cache = {}
def _deps(io):
if io not in deps_cache:
d = deps(io)
if d:
deps_cache[io] = list(d)
else:
deps_cache[io] = d
return d
else:
return deps_cache[io]
assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search( deque(r_out), _deps, 'dfs', True)
sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set()
rlist = []
while sources:
node = sources.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
for client in clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client] if a is not node]
if not deps_cache[client]:
sources.append(client)
if len(rlist) != len(reachable):
print ''
print reachable
print rlist
raise 'failed to complete topological sort of given nodes'
return rlist
def io_toposort(i, o, orderings = {}):
iset = set(i)
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, result.Result):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, op.Op):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
topo = general_toposort(o, deps)
return [o for o in topo if isinstance(o, op.Op)]
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
default_leaf_formatter = str
......@@ -430,4 +423,82 @@ class Graph:
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def _io_toposort(i, o, orderings = {}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
......@@ -61,6 +61,7 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2]
def partition(f, seq):
seqt = []
seqf = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论