提交 259716c6 authored 作者: Joseph Turian's avatar Joseph Turian

merge

...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase): ...@@ -124,9 +124,12 @@ class T_Function(unittest.TestCase):
def test_orphans(self): def test_orphans(self):
gi, go = graph1() gi, go = graph1()
opt = None opt = None
p0 = Function(gi[0:0], go) p0 = Function(gi[0:0], go, keep_locals=True)
self.failUnless(p0() == 1.0) self.failUnless(p0.orphans == go, p0.orphans)
vp0 = p0()
self.failUnless(vp0 == 1.0, vp0)
p3 = Function(gi,go) p3 = Function(gi,go)
p2 = Function(gi[0:2], go) p2 = Function(gi[0:2], go)
......
...@@ -6,6 +6,12 @@ from graph import * ...@@ -6,6 +6,12 @@ from graph import *
from op import Op from op import Op
from result import Result from result import Result
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
class MyResult(Result): class MyResult(Result):
...@@ -21,10 +27,10 @@ class MyResult(Result): ...@@ -21,10 +27,10 @@ class MyResult(Result):
return isinstance(other, MyResult) and other.thingy == self.thingy return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
def __repr__(self): def __repr__(self):
return str(self.thingy) return 'R%s' % str(self.thingy)
class MyOp(Op): class MyOp(Op):
...@@ -37,18 +43,19 @@ class MyOp(Op): ...@@ -37,18 +43,19 @@ class MyOp(Op):
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(op.outputs) == [r1, r2]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) i = inputs(op2.outputs)
self.failUnless(i == [r1, r2, r5], i)
def test_unreached_inputs(self): def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -57,23 +64,26 @@ class _test_inputs(unittest.TestCase): ...@@ -57,23 +64,26 @@ class _test_inputs(unittest.TestCase):
try: try:
# function doesn't raise if we put False instead of True # function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
return return
raise self.fail()
def test_uz(self):
pass
class _test_orphans(unittest.TestCase): class _test_orphans(testcase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) orph = orphans([r1, r2], op2.outputs)
self.failUnless(orph == [r5], orph)
class _test_as_string(unittest.TestCase): class _test_as_string(testcase):
leaf_formatter = str leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
...@@ -82,35 +92,37 @@ class _test_as_string(unittest.TestCase): ...@@ -82,35 +92,37 @@ class _test_as_string(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] s = as_string([r1, r2], op.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] s = as_string([r1, r2, r5], op2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op.outputs, op2.outputs) == ["MyOp(R3, R3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"] assert as_string(op2.inputs, op2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(unittest.TestCase): class _test_clone(testcase):
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert as_string([r1, r2], new) == ["MyOp(R1, R2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -130,13 +142,13 @@ class _test_clone(unittest.TestCase): ...@@ -130,13 +142,13 @@ class _test_clone(unittest.TestCase):
new_op = new[0].owner new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8) new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"] s = as_string(inputs(new_op.outputs), new_op.outputs)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"] self.failUnless( s == ["MyOp(R7, R8)"], s)
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#_test_inputs('test_unreached_inputs').debug()
from copy import copy from copy import copy
from collections import deque
import utils import utils
import result, op
__all__ = ['inputs', __all__ = ['inputs',
...@@ -14,33 +16,72 @@ __all__ = ['inputs', ...@@ -14,33 +16,72 @@ __all__ = ['inputs',
'as_string', 'as_string',
'Graph'] 'Graph']
is_result = lambda o: isinstance(o, result.Result)
is_op = lambda o: isinstance(o, op.Op)
is_result = utils.attr_checker('owner', 'index') def stack_search(start, expand, mode='bfs', build_inv = False):
is_op = utils.attr_checker('inputs', 'outputs') """Search through L{Result}s, either breadth- or depth-first
@type start: deque
@param start: search from these nodes
@type explore: function
@param explore: when we get to a node, add explore(node) to the list of
nodes to visit. This function should return a list, or None
def inputs(o): @rtype: list of L{Result}
""" @return: the list of L{Result}s in order of traversal.
@type o: list
@param o: output L{Result}s @note: a L{Result} will appear at most once in the return value, even if it
appears multiple times in the start parameter.
Returns the set of inputs necessary to compute the outputs in o @postcondition: every element of start is transferred to the returned list.
such that input.owner is None.
""" @postcondition: start is empty.
results = set()
def seek(r):
op = r.owner
if op is None:
results.add(r)
else:
for input in op.inputs:
seek(input)
for output in o:
seek(output)
return results
"""
if mode not in ('bfs', 'dfs'):
raise ValueError('mode should be bfs or dfs', mode)
rval_set = set()
rval_list = list()
start_pop = start.popleft if mode is 'bfs' else start.pop
expand_inv = {}
while start:
l = start_pop()
if id(l) not in rval_set:
rval_list.append(l)
rval_set.add(id(l))
expand_l = expand(l)
if expand_l:
if build_inv:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
assert len(rval_list) == len(rval_set)
if build_inv:
return rval_list, expand_inv
return rval_list
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
def results_and_orphans(i, o, except_unreachable_input=False): """
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
""" """
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
...@@ -86,6 +127,44 @@ def results_and_orphans(i, o, except_unreachable_input=False): ...@@ -86,6 +127,44 @@ def results_and_orphans(i, o, except_unreachable_input=False):
results.update(orphans) results.update(orphans)
return results, orphans return results, orphans
def results_and_orphans(r_in, r_out, except_unreachable_input=False):
r_in_set = set(r_in)
class Dummy(object): pass
dummy = Dummy()
dummy.inputs = r_out
def expand_inputs(io):
if io in r_in_set:
return None
try:
return [io.owner] if io.owner != None else None
except AttributeError:
return io.inputs
ops_and_results, dfsinv = stack_search(
deque([dummy]),
expand_inputs, 'dfs', True)
if except_unreachable_input:
for r in r_in:
if r not in dfsinv:
raise Exception(results_and_orphans.E_unreached)
clients = stack_search(
deque(r_in),
lambda io: dfsinv.get(io,None), 'dfs')
ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
results = []
for o in ops_to_compute:
results.extend(o.inputs)
results.extend(r_out)
op_set = set(ops_to_compute)
assert len(ops_to_compute) == len(op_set)
orphans = [r for r in results \
if (r.owner not in op_set) and (r not in r_in_set)]
return results, orphans
results_and_orphans.E_unreached = 'there were unreachable inputs' results_and_orphans.E_unreached = 'there were unreachable inputs'
...@@ -262,6 +341,8 @@ def as_string(i, o, ...@@ -262,6 +341,8 @@ def as_string(i, o,
exist for viewing convenience). exist for viewing convenience).
""" """
i = set(i)
orph = orphans(i, o) orph = orphans(i, o)
multi = set() multi = set()
......
...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines. ...@@ -6,7 +6,6 @@ compatible with gof's graph manipulation routines.
import utils import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
from copy import copy from copy import copy
...@@ -170,11 +169,11 @@ class Op(object): ...@@ -170,11 +169,11 @@ class Op(object):
# String representation # String representation
# #
def __str__(self): if 0:
return graph.op_as_string(self.inputs, self) def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
# #
......
...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception): ...@@ -19,6 +19,31 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
Use it like this:
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print
WARNING myfile.fn_name deprecated. do something different...
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\
% (filename, f.__name__, msg)
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论