提交 5301d18d authored 作者: James Bergstra's avatar James Bergstra

merged

...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase): ...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase):
def test_orphans(self): def test_orphans(self):
gi, go = graph1() gi, go = graph1()
opt = None opt = None
p0 = Function(gi[0:0], go) p0 = Function(gi[0:0], go, keep_locals=True)
self.failUnless(p0() == 1.0) self.failUnless(p0.orphans == go, p0.orphans)
vp0 = p0()
self.failUnless(vp0 == 1.0, vp0)
p3 = Function(gi,go) p3 = Function(gi,go)
p2 = Function(gi[0:2], go) p2 = Function(gi[0:2], go)
......
from collections import deque
import unittest import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import Result from result import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
class MyResult(Result): class MyResult(Result):
...@@ -21,10 +45,10 @@ class MyResult(Result): ...@@ -21,10 +45,10 @@ class MyResult(Result):
return isinstance(other, MyResult) and other.thingy == self.thingy return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
def __repr__(self): def __repr__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
class MyOp(Op): class MyOp(Op):
...@@ -37,18 +61,19 @@ class MyOp(Op): ...@@ -37,18 +61,19 @@ class MyOp(Op):
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(op.outputs) == [r1, r2]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) i = inputs(op2.outputs)
self.failUnless(i == [r1, r2, r5], i)
def test_unreached_inputs(self): def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -57,23 +82,26 @@ class _test_inputs(unittest.TestCase): ...@@ -57,23 +82,26 @@ class _test_inputs(unittest.TestCase):
try: try:
# function doesn't raise if we put False instead of True # function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
return return
raise self.fail()
def test_uz(self):
pass
class _test_orphans(unittest.TestCase):
class _test_orphans(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) orph = orphans([r1, r2], op2.outputs)
self.failUnless(orph == [r5], orph)
class _test_as_string(unittest.TestCase): class _test_as_string(testcase):
leaf_formatter = str leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
...@@ -82,35 +110,37 @@ class _test_as_string(unittest.TestCase): ...@@ -82,35 +110,37 @@ class _test_as_string(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] s = as_string([r1, r2], op.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] s = as_string([r1, r2, r5], op2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op.outputs, op2.outputs) == ["MyOp(R3, R3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op2.inputs, op2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(unittest.TestCase): class _test_clone(testcase):
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert as_string([r1, r2], new) == ["MyOp(R1, R2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -130,13 +160,91 @@ class _test_clone(unittest.TestCase): ...@@ -130,13 +160,91 @@ class _test_clone(unittest.TestCase):
new_op = new[0].owner new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8) new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"] s = as_string(inputs(new_op.outputs), new_op.outputs)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"] self.failUnless( s == ["MyOp(R7, R8)"], s)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def prenode(obj):
if isinstance(obj, Result):
if obj.owner:
return [obj.owner]
if isinstance(obj, Op):
return obj.inputs
class _test_toposort(testcase):
def test0(self):
"""Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r2)
o2 = MyOp(o.outputs[0], r5)
if __name__ == '__main__': all = general_toposort(o2.outputs, prenode)
unittest.main() self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all)
def test1(self):
"""Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
o2 = MyOp(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
def test2(self):
"""Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1)
r2b = o.outputs[0]
o2 = MyOp(r2b, r2b)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
o2 = MyOp(r2b, r5)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
def test3(self):
"""Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all)
def test4(self):
"""Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all)
def test5(self):
"""Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2)
o1 = MyOp(o0.outputs[0], r4)
all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all)
if __name__ == '__main__':
if 1:
#run all tests
unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
from copy import copy from copy import copy
from collections import deque
import utils import utils
import result, op
__all__ = ['inputs', __all__ = ['inputs',
'results_and_orphans', 'results', 'orphans', 'results_and_orphans', 'results', 'orphans', 'stack_search',
'ops', 'ops',
'clone', 'clone_get_equiv', 'clone', 'clone_get_equiv',
'io_toposort', 'io_toposort', 'general_toposort',
'default_leaf_formatter', 'default_node_formatter', 'default_leaf_formatter', 'default_node_formatter',
'op_as_string', 'op_as_string',
'as_string', 'as_string',
'Graph'] 'Graph']
is_result = lambda o: isinstance(o, result.Result)
is_op = lambda o: isinstance(o, op.Op)
is_result = utils.attr_checker('owner', 'index') def stack_search(start, expand, mode='bfs', build_inv = False):
is_op = utils.attr_checker('inputs', 'outputs') """Search through L{Result}s, either breadth- or depth-first
@type start: deque
@param start: search from these nodes
@type explore: function
@param explore: when we get to a node, add explore(node) to the list of
nodes to visit. This function should return a list, or None
def inputs(o): @rtype: list of L{Result}
""" @return: the list of L{Result}s in order of traversal.
@type o: list
@param o: output L{Result}s @note: a L{Result} will appear at most once in the return value, even if it
appears multiple times in the start parameter.
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
results = set()
def seek(r):
op = r.owner
if op is None:
results.add(r)
else:
for input in op.inputs:
seek(input)
for output in o:
seek(output)
return results
@postcondition: every element of start is transferred to the returned list.
@postcondition: start is empty.
def results_and_orphans(i, o, except_unreachable_input=False):
""" """
@type i: list if mode not in ('bfs', 'dfs'):
@param i: input L{Result}s raise ValueError('mode should be bfs or dfs', mode)
@type o: list rval_set = set()
@param o: output L{Result}s rval_list = list()
start_pop = start.popleft if mode is 'bfs' else start.pop
Returns the pair (results, orphans). The former is the set of expand_inv = {}
L{Result}s that are involved in the subgraph that lies between i and while start:
o. This includes i, o, orphans(i, o) and all results of all l = start_pop()
intermediary steps from i to o. The second element of the returned if id(l) not in rval_set:
pair is orphans(i, o). rval_list.append(l)
rval_set.add(id(l))
expand_l = expand(l)
if expand_l:
if build_inv:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
assert len(rval_list) == len(rval_set)
if build_inv:
return rval_list, expand_inv
return rval_list
@utils.deprecated('gof.graph', 'is this function ever used?')
def inputs(result_list):
""" """
results = set() @type result_list: list of L{Result}
i = set(i) @param result_list: output L{Result}s (from which to search backward through owners)
results.update(i) @returns: the list of L{Result}s with no owner, in the order found by a
incomplete_paths = [] left-recursive depth-first search started at the L{Result}s in result_list.
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached): """
raise Exception(results_and_orphans.E_unreached) def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
def results_and_orphans(r_in, r_out, except_unreachable_input=False):
r_in_set = set(r_in)
class Dummy(object): pass
dummy = Dummy()
dummy.inputs = r_out
def expand_inputs(io):
if io in r_in_set:
return None
try:
return [io.owner] if io.owner != None else None
except AttributeError:
return io.inputs
ops_and_results, dfsinv = stack_search(
deque([dummy]),
expand_inputs, 'dfs', True)
if except_unreachable_input:
for r in r_in:
if r not in dfsinv:
raise Exception(results_and_orphans.E_unreached)
clients = stack_search(
deque(r_in),
lambda io: dfsinv.get(io,None), 'dfs')
ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
results.update(orphans) results = []
for o in ops_to_compute:
results.extend(o.inputs)
results.extend(r_out)
op_set = set(ops_to_compute)
assert len(ops_to_compute) == len(op_set)
orphans = [r for r in results \
if (r.owner not in op_set) and (r not in r_in_set)]
return results, orphans return results, orphans
results_and_orphans.E_unreached = 'there were unreachable inputs' results_and_orphans.E_unreached = 'there were unreachable inputs'
...@@ -203,31 +236,70 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -203,31 +236,70 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
return d return d
def general_toposort(r_out, deps):
"""
@note: deps(i) should behave like a pure function (no funny business with
internal state)
def io_toposort(i, o, orderings = {}): @note: deps(i) can/should be cached by the deps function to be fast
""" """
@type i: list deps_cache = {}
@param i: input L{Result}s def _deps(io):
@type o: list if io not in deps_cache:
@param o: output L{Result}s d = deps(io)
@param orderings: {op: [requirements for op]} (defaults to {}) if d:
deps_cache[io] = list(d)
else:
deps_cache[io] = d
return d
else:
return deps_cache[io]
assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search( deque(r_out), _deps, 'dfs', True)
sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set()
rlist = []
while sources:
node = sources.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
for client in clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client] if a is not node]
if not deps_cache[client]:
sources.append(client)
if len(rlist) != len(reachable):
print ''
print reachable
print rlist
raise 'failed to complete topological sort of given nodes'
return rlist
def io_toposort(i, o, orderings = {}):
iset = set(i)
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, result.Result):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, op.Op):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
topo = general_toposort(o, deps)
return [o for o in topo if isinstance(o, op.Op)]
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
default_leaf_formatter = str default_leaf_formatter = str
...@@ -262,6 +334,8 @@ def as_string(i, o, ...@@ -262,6 +334,8 @@ def as_string(i, o,
exist for viewing convenience). exist for viewing convenience).
""" """
i = set(i)
orph = orphans(i, o) orph = orphans(i, o)
multi = set() multi = set()
...@@ -349,4 +423,82 @@ class Graph: ...@@ -349,4 +423,82 @@ class Graph:
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def _io_toposort(i, o, orderings = {}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines. ...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines.
import utils import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
from copy import copy from copy import copy
...@@ -170,11 +169,11 @@ class Op(object): ...@@ -170,11 +169,11 @@ class Op(object):
# String representation # String representation
# #
def __str__(self): if 0:
return graph.op_as_string(self.inputs, self) def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
# #
......
...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception): ...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
Use it like this:
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print
WARNING myfile.fn_name deprecated. do something different...
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\
% (filename, f.__name__, msg)
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
...@@ -36,6 +61,7 @@ def difference(seq1, seq2): ...@@ -36,6 +61,7 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2] return [x for x in seq1 if x not in seq2]
def partition(f, seq): def partition(f, seq):
seqt = [] seqt = []
seqf = [] seqf = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论