提交 5301d18d，作者：James Bergstra

merged

......@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase):
def test_orphans(self):
gi, go = graph1()
opt = None
p0 = Function(gi[0:0], go)
p0 = Function(gi[0:0], go, keep_locals=True)
self.failUnless(p0() == 1.0)
self.failUnless(p0.orphans == go, p0.orphans)
vp0 = p0()
self.failUnless(vp0 == 1.0, vp0)
p3 = Function(gi,go)
p2 = Function(gi[0:2], go)
......
from collections import deque
import unittest
from graph import *
from op import Op
from result import Result
def inputs(result_list):
    """
    Collect the ownerless L{Result}s found by searching backward from outputs.

    @type result_list: list of L{Result}
    @param result_list: output L{Result}s (the backward search through owners
    starts from these)

    @returns: the list of L{Result}s with no owner, in the order in which a
    depth-first search started at the L{Result}s in result_list visits them.
    """
    def owner_inputs(node):
        # A node without an owner is a leaf: there is nothing more to visit.
        if not node.owner:
            return None
        # Hand the inputs over in reverse order, as the search is run in
        # 'dfs' mode.
        return list(node.owner.inputs)[::-1]

    visited = stack_search(deque(result_list), owner_inputs, 'dfs')
    return [node for node in visited if node.owner is None]
# Base class used by the test classes in this file.  Replace the literal 1 by
# 0 to make them plain objects, which hides them from unittest discovery.
realtestcase = unittest.TestCase
testcase = realtestcase if 1 else object
class MyResult(Result):
......@@ -21,10 +45,10 @@ class MyResult(Result):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self):
return str(self.thingy)
return 'R%s' % str(self.thingy)
def __repr__(self):
return str(self.thingy)
return 'R%s' % str(self.thingy)
class MyOp(Op):
......@@ -37,18 +61,19 @@ class MyOp(Op):
self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase):
class _test_inputs(testcase):
def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2])
assert inputs(op.outputs) == [r1, r2]
def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5])
i = inputs(op2.outputs)
self.failUnless(i == [r1, r2, r5], i)
def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
......@@ -57,23 +82,26 @@ class _test_inputs(unittest.TestCase):
try:
# function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e:
if e[0] is results_and_orphans.E_unreached:
return
raise
self.fail()
def test_uz(self):
pass
class _test_orphans(unittest.TestCase):
class _test_orphans(testcase):
def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
orph = orphans([r1, r2], op2.outputs)
self.failUnless(orph == [r5], orph)
class _test_as_string(unittest.TestCase):
class _test_as_string(testcase):
leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
......@@ -82,35 +110,37 @@ class _test_as_string(unittest.TestCase):
def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
s = as_string([r1, r2], op.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
s = as_string([r1, r2, r5], op2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"]
assert as_string(op.outputs, op2.outputs) == ["MyOp(R3, R3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(unittest.TestCase):
class _test_clone(testcase):
def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
assert as_string([r1, r2], new) == ["MyOp(R1, R2)"]
def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
......@@ -130,13 +160,91 @@ class _test_clone(unittest.TestCase):
new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"]
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"]
s = as_string(inputs(new_op.outputs), new_op.outputs)
self.failUnless( s == ["MyOp(R7, R8)"], s)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def prenode(obj):
    """
    Dependency function handed to general_toposort by the tests below.

    @param obj: a L{Result} or an L{Op}
    @returns: [obj.owner] for a L{Result} that has an owner, obj.inputs for an
    L{Op}, and None for anything else (including an ownerless L{Result}).
    """
    if isinstance(obj, Result) and obj.owner:
        return [obj.owner]
    if isinstance(obj, Op):
        return obj.inputs
    return None
class _test_toposort(testcase):
    """Checks general_toposort and io_toposort on small hand-built graphs."""

    def test0(self):
        """Test a simple graph"""
        r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
        o = MyOp(r1, r2)
        o2 = MyOp(o.outputs[0], r5)

        # assertEqual reports both operands on failure, so no explicit
        # message argument is needed.
        order = general_toposort(o2.outputs, prenode)
        self.assertEqual(
            order, [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]])

        order = io_toposort([r5], o2.outputs)
        self.assertEqual(order, [o, o2])

    def test1(self):
        """Test a graph with double dependencies"""
        r1, r5 = MyResult(1), MyResult(5)
        o = MyOp(r1, r1)
        o2 = MyOp(o.outputs[0], r5)
        order = general_toposort(o2.outputs, prenode)
        self.assertEqual(order, [r5, r1, o, o.outputs[0], o2, o2.outputs[0]])

    def test2(self):
        """Test a graph where the inputs have owners"""
        r1, r5 = MyResult(1), MyResult(5)
        o = MyOp(r1, r1)
        r2b = o.outputs[0]

        o2 = MyOp(r2b, r2b)
        order = io_toposort([r2b], o2.outputs)
        self.assertEqual(order, [o2])

        o2 = MyOp(r2b, r5)
        order = io_toposort([r2b], o2.outputs)
        self.assertEqual(order, [o2])

    def test3(self):
        """Test a graph which is not connected"""
        r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
        o0 = MyOp(r1, r2)
        o1 = MyOp(r3, r4)
        order = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
        self.assertEqual(order, [o1, o0])

    def test4(self):
        """Test inputs and outputs mixed together in a chain graph"""
        r1, r2 = MyResult(1), MyResult(2)
        o0 = MyOp(r1, r2)
        o1 = MyOp(o0.outputs[0], r1)
        order = io_toposort([r1, o0.outputs[0]],
                            [o0.outputs[0], o1.outputs[0]])
        self.assertEqual(order, [o1])

    def test5(self):
        """Test when outputs have clients"""
        r1, r2, r4 = MyResult(1), MyResult(2), MyResult(4)
        o0 = MyOp(r1, r2)
        # Built only so that o0's output has a client; it must not show up in
        # the ordering below.
        MyOp(o0.outputs[0], r4)
        order = io_toposort([], o0.outputs)
        self.assertEqual(order, [o0])
if __name__ == '__main__':
    # Developer switch: edit run_mode by hand to choose what gets executed.
    #   'all'    -> run every test in this file
    #   'class'  -> run only the _test_toposort class, verbosely
    #   other    -> run the single test _test_toposort.test0 in debug mode
    run_mode = 'all'
    if run_mode == 'all':
        unittest.main()
    elif run_mode == 'class':
        loader = unittest.TestLoader()
        selected = loader.loadTestsFromTestCase(_test_toposort)
        unittest.TextTestRunner(verbosity=2).run(selected)
    else:
        _test_toposort('test0').debug()
from copy import copy
from collections import deque
import utils
import result, op
__all__ = ['inputs',
'results_and_orphans', 'results', 'orphans',
'results_and_orphans', 'results', 'orphans', 'stack_search',
'ops',
'clone', 'clone_get_equiv',
'io_toposort',
'io_toposort', 'general_toposort',
'default_leaf_formatter', 'default_node_formatter',
'op_as_string',
'as_string',
'Graph']
is_result = lambda o: isinstance(o, result.Result)
is_op = lambda o: isinstance(o, op.Op)
is_result = utils.attr_checker('owner', 'index')
is_op = utils.attr_checker('inputs', 'outputs')
def stack_search(start, expand, mode='bfs', build_inv=False):
    """Search through L{Result}s, either breadth- or depth-first

    @type start: deque
    @param start: search from these nodes
    @type expand: function
    @param expand: when we get to a node, add expand(node) to the list of
                   nodes to visit.  This function should return a list, or None
    @type mode: str
    @param mode: 'bfs' to take the next node from the left end of start
                 (breadth-first), 'dfs' to take it from the right end
                 (depth-first)
    @type build_inv: bool
    @param build_inv: when True, also return a dictionary mapping every node
                      returned by expand to the list of nodes whose expansion
                      returned it

    @rtype: list of L{Result}, or the pair (list, dict) when build_inv is True
    @return: the list of L{Result}s in order of traversal.

    @raise ValueError: if mode is neither 'bfs' nor 'dfs'

    @note: a L{Result} will appear at most once in the return value, even if it
    appears multiple times in the start parameter.  Nodes are told apart by
    id(), not by equality.

    @postcondition: every element of start is transferred to the returned list.
    @postcondition: start is empty.
    """
    if mode not in ('bfs', 'dfs'):
        raise ValueError('mode should be bfs or dfs', mode)
    rval_set = set()
    rval_list = list()
    # Compare by value: an equal but non-identical 'bfs' string must not
    # silently fall through to depth-first behaviour.
    start_pop = start.popleft if mode == 'bfs' else start.pop
    expand_inv = {}
    while start:
        l = start_pop()
        if id(l) not in rval_set:
            rval_list.append(l)
            rval_set.add(id(l))
            expand_l = expand(l)
            if expand_l:
                if build_inv:
                    for r in expand_l:
                        expand_inv.setdefault(r, []).append(l)
                start.extend(expand_l)
    assert len(rval_list) == len(rval_set)
    if build_inv:
        return rval_list, expand_inv
    return rval_list
@utils.deprecated('gof.graph', 'is this function ever used?')
def inputs(result_list):
    """
    Collect the ownerless L{Result}s found by searching backward from outputs.

    @type result_list: list of L{Result}
    @param result_list: output L{Result}s (the backward search through owners
    starts from these)

    @returns: the list of L{Result}s with no owner, in the order in which a
    depth-first search started at the L{Result}s in result_list visits them.
    """
    def owner_inputs(node):
        # A node without an owner is a leaf: there is nothing more to visit.
        if not node.owner:
            return None
        # stack_search takes nodes from the right end in 'dfs' mode, so the
        # inputs are handed over in reverse order.
        return list(node.owner.inputs)[::-1]

    visited = stack_search(deque(result_list), owner_inputs, 'dfs')
    return [node for node in visited if node.owner is None]
def results_and_orphans(r_in, r_out, except_unreachable_input=False):
    """
    @type r_in: list
    @param r_in: input L{Result}s
    @type r_out: list
    @param r_out: output L{Result}s
    @type except_unreachable_input: bool
    @param except_unreachable_input: when True, raise if some element of r_in
    is never encountered while searching backward from r_out

    @rtype: pair of lists
    @return: (results, orphans).  results holds the inputs of every op that
    has to be computed to go from r_in to r_out, followed by r_out itself (a
    L{Result} may be listed more than once).  orphans holds the elements of
    results that are neither in r_in nor owned by one of those ops.

    @raise Exception: carrying results_and_orphans.E_unreached, under the
    condition described for except_unreachable_input
    """
    r_in_set = set(r_in)

    # Artificial root node whose "inputs" are the requested outputs, so that a
    # single search started from it covers every element of r_out.
    class Dummy(object): pass
    dummy = Dummy()
    dummy.inputs = r_out

    def expand_inputs(io):
        # The backward search stops at the given inputs.
        if io in r_in_set:
            return None
        try:
            # Identity test: owner objects may define their own comparison.
            return [io.owner] if io.owner is not None else None
        except AttributeError:
            # No `owner` attribute: io is an op (or the dummy root).
            return io.inputs

    # dfsinv maps each node to the nodes that were expanded into it.
    ops_and_results, dfsinv = stack_search(
            deque([dummy]),
            expand_inputs, 'dfs', True)

    if except_unreachable_input:
        for r in r_in:
            if r not in dfsinv:
                # Callers identify this error by the identity of its first
                # argument, so the constant is passed through unchanged.
                raise Exception(results_and_orphans.E_unreached)

    # Walk forward again from the inputs, along the inverse edges only.
    clients = stack_search(
            deque(r_in),
            lambda io: dfsinv.get(io, None), 'dfs')

    ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]

    results = []
    for o in ops_to_compute:
        results.extend(o.inputs)
    results.extend(r_out)

    op_set = set(ops_to_compute)
    assert len(ops_to_compute) == len(op_set)
    orphans = [r for r in results
               if (r.owner not in op_set) and (r not in r_in_set)]
    return results, orphans
results_and_orphans.E_unreached = 'there were unreachable inputs'
......@@ -203,31 +236,70 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
return d
def general_toposort(r_out, deps):
    """
    Order the nodes reachable from r_out so that every node comes after the
    nodes it depends on.

    @type r_out: tuple, list or deque
    @param r_out: the nodes from which the search starts
    @type deps: function
    @param deps: deps(node) returns the nodes that must come before node, or a
    false value when there are none

    @rtype: list
    @return: every node reachable from r_out through deps, dependencies first

    @raise ValueError: if some reachable node could not be placed in the
    ordering; the reachable nodes and the partial ordering are attached to
    the exception.  NOTE(review): presumably this means the dependencies
    contain a cycle -- confirm.

    @note: deps(i) should behave like a pure function (no funny business with
    internal state)

    @note: deps(i) can/should be cached by the deps function to be fast
    """
    deps_cache = {}

    def _deps(io):
        if io not in deps_cache:
            d = deps(io)
            if d:
                # Keep a private copy: entries are consumed further down.
                deps_cache[io] = list(d)
            else:
                deps_cache[io] = d
            return d
        else:
            return deps_cache[io]

    assert isinstance(r_out, (tuple, list, deque))

    # clients maps each node to the nodes that listed it as a dependency.
    reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)
    sources = deque([r for r in reachable if not deps_cache.get(r, None)])

    rset = set()
    rlist = []
    while sources:
        node = sources.popleft()
        if node not in rset:
            rlist.append(node)
            rset.add(node)
            for client in clients.get(node, []):
                # node is now placed: drop it from what client still awaits.
                deps_cache[client] = [a for a in deps_cache[client]
                                      if a is not node]
                if not deps_cache[client]:
                    sources.append(client)

    if len(rlist) != len(reachable):
        # A real exception instead of a string: string exceptions cannot be
        # raised on Python >= 2.6.  The diagnostic data travels in the args.
        raise ValueError('failed to complete topological sort of given nodes',
                         reachable, rlist)

    return rlist
def io_toposort(i, o, orderings=None):
    """
    @type i: list
    @param i: input L{Result}s
    @type o: list
    @param o: output L{Result}s
    @param orderings: {op: [requirements for op]} (defaults to no additional
    requirements)

    @rtype: ordered list
    @return: L{Op}s that belong in the subgraph between i and o which
    respects the following constraints:
     - all inputs in i are assumed to be already computed
     - the L{Op}s that compute an L{Op}'s inputs must be computed before it
     - the orderings specified in the optional orderings parameter must be
       satisfied

    Note that this function does not take into account ordering information
    related to destructive operations or other special behavior.
    """
    if orderings is None:
        orderings = {}
    iset = set(i)

    def deps(obj):
        rval = []
        if obj not in iset:
            if isinstance(obj, result.Result):
                if obj.owner:
                    rval = [obj.owner]
            if isinstance(obj, op.Op):
                rval = list(obj.inputs)
            rval.extend(orderings.get(obj, []))
        else:
            # Inputs count as already computed, so they cannot carry
            # additional requirements.
            assert not orderings.get(obj, [])
        return rval

    topo = general_toposort(o, deps)
    # The ordering interleaves results and ops; only the ops are returned.
    return [node for node in topo if isinstance(node, op.Op)]
default_leaf_formatter = str
......@@ -262,6 +334,8 @@ def as_string(i, o,
exist for viewing convenience).
"""
i = set(i)
orph = orphans(i, o)
multi = set()
......@@ -349,4 +423,82 @@ class Graph:
# Dead code: the constant-false guard below means that nothing in this block
# is executed, so neither function is defined when the module is imported.
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
# Paths that ended on an ownerless result which is not one of the inputs.
incomplete_paths = []
# Inputs that were actually encountered while walking back from o.
reached = set()
# Recursive walk from a result back through owner.inputs; `path` is the list
# of results visited so far on the way from an output.
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
# On each incomplete path, the first result not already in `results` is
# recorded as an orphan.
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
# The error constant is read from the current results_and_orphans function,
# not from this legacy copy.
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def _io_toposort(i, o, orderings = {}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
# copy() is shallow: the values stored in `orderings` are shared with the
# caller's dictionary.
# NOTE(review): .update() below requires those values to be sets, although
# the docstring writes them as lists -- confirm against callers.
prereqs_d = copy(orderings)
# NOTE(review): the names `all` and `op` used below hide the builtin and the
# imported `op` module inside this function.
all = ops(i, o)
for op in all:
# Prerequisites of op: the owners of its inputs that are part of the subgraph.
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
......@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines.
import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
from copy import copy
......@@ -170,11 +169,11 @@ class Op(object):
# String representation
#
def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self):
return str(self)
if 0:
def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self):
return str(self)
#
......
......@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class.
"""
def deprecated(filename, msg=''):
    """Decorator which will print a warning message on the first call.

    Use it like this:

      @deprecated('myfile', 'do something different...')
      def fn_name(...)
          ...

    And it will print

      WARNING: myfile.fn_name deprecated. do something different...

    @param filename: name printed in front of the function name
    @param msg: extra advice appended to the warning (defaults to '')
    @return: a decorator; the function it produces prints the warning on its
    first call only, forwards all arguments to the decorated function and
    returns its result.  The decorated function's __name__ and __doc__ are
    kept on the wrapper.
    """
    import functools

    def _deprecated(f):
        # One-element list: a flag that the nested function can modify.
        printme = [True]

        @functools.wraps(f)
        def g(*args, **kwargs):
            if printme[0]:
                # Single parenthesised argument: prints the same text whether
                # print is a statement or a function.
                print('WARNING: %s.%s deprecated. %s'
                      % (filename, f.__name__, msg))
                printme[0] = False
            return f(*args, **kwargs)
        return g

    return _deprecated
def uniq(seq):
    """Return the elements of seq as a list, keeping only the first of any
    elements that compare equal and preserving their order."""
    #TODO: consider building a set out of seq so that the if condition is constant time -JB
    kept = []
    for item in seq:
        # Membership is tested by equality, so unhashable elements work too.
        if item not in kept:
            kept.append(item)
    return kept
......@@ -36,6 +61,7 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2]
def partition(f, seq):
seqt = []
seqf = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论