upgraded DestroyHandler to handle result.indestructible when set, added mucho tests

上级 a5129f65
import unittest
from result import ResultBase
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from ext import *
from env import Env, InconsistencyError
from toolbox import EquivTool
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
def __str__(self):
return self.name
def __repr__(self):
return self.name
class MyOp(Op):
nin = -1
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
assert len(inputs) == self.nin
for input in inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [MyResult(self.__class__.__name__ + "_R")]
class Sigmoid(MyOp):
nin = 1
class TransposeView(MyOp, Viewer):
nin = 1
def view_map(self):
return {self.outputs[0]: [self.inputs[0]]}
class Add(MyOp):
nin = 2
class AddInPlace(MyOp, Destroyer):
nin = 2
def destroyed_inputs(self):
return self.inputs[:1]
class Dot(MyOp):
nin = 2
dtv_elim = PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x')
a2i = OpSubOptimizer(Add, AddInPlace)
i2a = OpSubOptimizer(AddInPlace, Add)
t2s = OpSubOptimizer(TransposeView, Sigmoid)
s2t = OpSubOptimizer(Sigmoid, TransposeView)
class _test_all(unittest.TestCase):
def inputs(self):
x = MyResult('x')
y = MyResult('y')
z = MyResult('z')
return x, y, z
def env(self, inputs, outputs, validate = True):
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
def test_0(self):
x, y, z = self.inputs()
e = Add(AddInPlace(x, y), AddInPlace(x, y))
try:
g = self.env([x,y,z], [e])
except InconsistencyError, e:
pass
else:
raise Exception("Expected an InconsistencyError")
def test_1(self):
# the loop is needed because a2i will optimize in a random order and sometimes
# only one of them fails
for i in xrange(100):
x, y, z = self.inputs()
e = Add(Add(x, y), Add(y, x))
g = self.env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[AddInPlace(AddInPlace(x, y), AddInPlace(y, x))]"
def test_2(self):
x, y, z = self.inputs()
g = self.env([x,y,z], [Dot(AddInPlace(x, z), x)], False)
assert not g.consistent()
i2a.optimize(g)
assert g.consistent()
def test_3(self):
for i in xrange(100):
x, y, z = self.inputs()
e = Dot(Add(TransposeView(z), y), Add(z, x))
g = self.env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[Dot(AddInPlace(TransposeView(z), y), AddInPlace(z, x))]"
def test_4(self):
x, y, z = self.inputs()
e = Dot(AddInPlace(x,y), TransposeView(x))
g = self.env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e.owner.inputs[1], Add(x,z))
assert g.consistent()
def test_5(self):
x, y, z = self.inputs()
e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x))))))
g = self.env([x,y,z], [e])
assert g.consistent()
g.replace(e.owner.inputs[1].owner.inputs[0], x, False)
assert not g.consistent()
def test_6(self):
for i in xrange(100):
x, y, z = self.inputs()
e = Dot(AddInPlace(x,Sigmoid(y)), Sigmoid(Sigmoid(Sigmoid(Sigmoid(Sigmoid(x))))))
g = self.env([x,y,z], [e])
assert g.consistent()
s2t.optimize(g)
assert g.consistent()
assert str(g) != "[Dot(AddInPlace(x,TransposeView(y)), TransposeView(TransposeView(TransposeView(TransposeView(TransposeView(x))))))]"
def test_7(self):
x, y, z = self.inputs()
e = TransposeView(TransposeView(TransposeView(TransposeView(x))))
g = self.env([x,y,z], [e])
assert g.consistent()
chk = g.checkpoint()
dtv_elim.optimize(g)
assert str(g) == "[x]"
g.replace(g.equiv(e), Add(x,y))
assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), Dot(AddInPlace(x,y), TransposeView(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent()
g.revert(chk)
assert g.consistent()
assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
def test_8(self):
x, y, z = self.inputs()
e = Dot(Dot(AddInPlace(x,y), AddInPlace(y,z)), Add(z,x))
g = self.env([x,y,z], [e])
assert g.consistent()
a2i.optimize(g)
assert g.consistent()
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
def test_9(self):
x, y, z = self.inputs()
x.indestructible = True
e = AddInPlace(x, y)
g = self.env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e, Add(x, y))
assert g.consistent()
def test_10(self):
x, y, z = self.inputs()
x.indestructible = True
tv = TransposeView(x)
e = AddInPlace(tv, y)
g = self.env([x,y,z], [e], False)
assert not g.consistent()
g.replace(tv, Sigmoid(x))
assert g.consistent()
if __name__ == '__main__':
unittest.main()
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import ResultBase, BrokenLinkError from result import ResultBase
class MyResult(ResultBase): class MyResult(ResultBase):
...@@ -14,6 +14,9 @@ class MyResult(ResultBase): ...@@ -14,6 +14,9 @@ class MyResult(ResultBase):
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False) ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other)
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
...@@ -25,11 +28,12 @@ class MyResult(ResultBase): ...@@ -25,11 +28,12 @@ class MyResult(ResultBase):
class MyOp(Op): class MyOp(Op):
def validate_update(self): def __init__(self, *inputs):
for input in self.inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input, MyResult):
raise Exception("Error 1") raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))] self.inputs = inputs
self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(unittest.TestCase):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import ResultBase, BrokenLinkError from result import ResultBase #, BrokenLinkError
class MyResult(ResultBase): class MyResult(ResultBase):
...@@ -12,6 +12,9 @@ class MyResult(ResultBase): ...@@ -12,6 +12,9 @@ class MyResult(ResultBase):
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False) ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other)
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
...@@ -23,11 +26,33 @@ class MyResult(ResultBase): ...@@ -23,11 +26,33 @@ class MyResult(ResultBase):
class MyOp(Op): class MyOp(Op):
def validate_update(self): def __init__(self, *inputs):
for input in self.inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input, MyResult):
raise Exception("Error 1") raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))] self.inputs = inputs
self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# if self.outputs is None:
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
# return True
# else:
# old_thingy = self.outputs[0].thingy
# new_thingy = sum([input.thingy for input in self.inputs])
# self.outputs[0].thingy = new_thingy
# return old_thingy != new_thingy
# class MyOp(Op):
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase): class _test_Op(unittest.TestCase):
...@@ -38,6 +63,7 @@ class _test_Op(unittest.TestCase): ...@@ -38,6 +63,7 @@ class _test_Op(unittest.TestCase):
op = MyOp(r1, r2) op = MyOp(r1, r2)
assert op.inputs == [r1, r2] # Are the inputs what I provided? assert op.inputs == [r1, r2] # Are the inputs what I provided?
assert op.outputs == [MyResult(3)] # Are the outputs what I expect? assert op.outputs == [MyResult(3)] # Are the outputs what I expect?
assert op.outputs[0].owner is op and op.outputs[0].index == 0
# validate_update # validate_update
def test_validate_update(self): def test_validate_update(self):
...@@ -48,97 +74,99 @@ class _test_Op(unittest.TestCase): ...@@ -48,97 +74,99 @@ class _test_Op(unittest.TestCase):
else: else:
raise Exception("Expected an exception") raise Exception("Expected an exception")
# Setting inputs and outputs # # Setting inputs and outputs
def test_set_inputs(self): # def test_set_inputs(self):
r1, r2 = MyResult(1), MyResult(2) # r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # op = MyOp(r1, r2)
r3 = op.outputs[0] # r3 = op.outputs[0]
op.inputs = MyResult(4), MyResult(5) # op.inputs = MyResult(4), MyResult(5)
assert op.outputs == [MyResult(9)] # check if the output changed to what I expect # op.validate_update()
assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output # assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
# # assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
def test_set_bad_inputs(self):
op = MyOp(MyResult(1), MyResult(2)) # def test_set_bad_inputs(self):
try: # op = MyOp(MyResult(1), MyResult(2))
op.inputs = MyResult(4), ResultBase() # try:
except Exception, e: # op.inputs = MyResult(4), ResultBase()
assert str(e) == "Error 1" # op.validate_update()
else: # except Exception, e:
raise Exception("Expected an exception") # assert str(e) == "Error 1"
# else:
def test_set_outputs(self): # raise Exception("Expected an exception")
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # here we only make one output # def test_set_outputs(self):
try: # r1, r2 = MyResult(1), MyResult(2)
op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail # op = MyOp(r1, r2) # here we only make one output
except TypeError, e: # try:
assert str(e) == "The new outputs must be exactly as many as the previous outputs." # op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
else: # except TypeError, e:
raise Exception("Expected an exception") # assert str(e) == "The new outputs must be exactly as many as the previous outputs."
# else:
# raise Exception("Expected an exception")
# Tests about broken links
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2) # # Tests about broken links
op = MyOp(r1, r2) # def test_create_broken_link(self):
r3 = op.out # r1, r2 = MyResult(1), MyResult(2)
op.inputs = MyResult(3), MyResult(4) # op = MyOp(r1, r2)
assert r3 not in op.outputs # r3 = op.out
assert r3.replaced # op.inputs = MyResult(3), MyResult(4)
# assert r3 not in op.outputs
def test_cannot_copy_when_input_is_broken_link(self): # assert r3.replaced
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # def test_cannot_copy_when_input_is_broken_link(self):
r3 = op.out # r1, r2 = MyResult(1), MyResult(2)
op2 = MyOp(r3) # op = MyOp(r1, r2)
op.inputs = MyResult(3), MyResult(4) # r3 = op.out
try: # op2 = MyOp(r3)
copy(op2) # op.inputs = MyResult(3), MyResult(4)
except BrokenLinkError: # try:
pass # copy(op2)
else: # except BrokenLinkError:
raise Exception("Expected an exception") # pass
# else:
def test_get_input_broken_link(self): # raise Exception("Expected an exception")
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # def test_get_input_broken_link(self):
r3 = op.out # r1, r2 = MyResult(1), MyResult(2)
op2 = MyOp(r3) # op = MyOp(r1, r2)
op.inputs = MyResult(3), MyResult(4) # r3 = op.out
try: # op2 = MyOp(r3)
op2.get_input(0) # op.inputs = MyResult(3), MyResult(4)
except BrokenLinkError: # try:
pass # op2.get_input(0)
else: # except BrokenLinkError:
raise Exception("Expected an exception") # pass
# else:
def test_get_inputs_broken_link(self): # raise Exception("Expected an exception")
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # def test_get_inputs_broken_link(self):
r3 = op.out # r1, r2 = MyResult(1), MyResult(2)
op2 = MyOp(r3) # op = MyOp(r1, r2)
op.inputs = MyResult(3), MyResult(4) # r3 = op.out
try: # op2 = MyOp(r3)
op2.get_inputs() # op.inputs = MyResult(3), MyResult(4)
except BrokenLinkError: # try:
pass # op2.get_inputs()
else: # except BrokenLinkError:
raise Exception("Expected an exception") # pass
# else:
def test_repair_broken_link(self): # raise Exception("Expected an exception")
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # def test_repair_broken_link(self):
r3 = op.out # r1, r2 = MyResult(1), MyResult(2)
op2 = MyOp(r3, MyResult(10)) # op = MyOp(r1, r2)
op.inputs = MyResult(3), MyResult(4) # r3 = op.out
op2.repair() # op2 = MyOp(r3, MyResult(10))
assert op2.outputs == [MyResult(17)] # op.inputs = MyResult(3), MyResult(4)
# op2.repair()
# Tests about string representation # assert op2.outputs == [MyResult(17)]
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2) # # Tests about string representation
op = MyOp(r1, r2) # def test_create_broken_link(self):
assert str(op) == "MyOp(1, 2)" # r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# assert str(op) == "MyOp(1, 2)"
......
import unittest
from result import ResultBase
from op import Op
from opt import *
from env import Env
from toolbox import *
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
def __str__(self):
return self.name
def __repr__(self):
return self.name
class MyOp(Op):
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
for input in inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [MyResult(self.__class__.__name__ + "_R")]
class Op1(MyOp):
pass
class Op2(MyOp):
pass
class Op3(MyOp):
pass
class Op4(MyOp):
pass
def inputs():
x = MyResult('x')
y = MyResult('y')
z = MyResult('z')
return x, y, z
def env(inputs, outputs, validate = True):
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
class _test_PatternOptimizer(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
def test_1(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'),
(Op4, '2', '1')).optimize(g)
assert str(g) != "[Op4(z, y)]"
def test_2(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'),
(Op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]"
def test_3(self):
x, y, z = inputs()
e = Op1(Op2(x, y), Op2(x, y), Op2(y, z))
g = env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'),
(Op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
def test_4(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(x))))
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')),
'1').optimize(g)
assert str(g) == "[x]"
def test_5(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(Op1(x)))))
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')),
'1').optimize(g)
assert str(g) == "[Op1(x)]"
class _test_OpSubOptimizer(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(Op1(x)))))
g = env([x, y, z], [e])
OpSubOptimizer(Op1, Op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_1(self):
x, y, z = inputs()
e = Op1(Op2(x), Op3(y), Op4(z))
g = env([x, y, z], [e])
OpSubOptimizer(Op3, Op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
if __name__ == '__main__':
unittest.main()
import unittest
from result import ResultBase
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from env import Env, InconsistencyError
from toolbox import *
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
def __str__(self):
return self.name
def __repr__(self):
return self.name
class MyOp(Op):
nin = -1
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
assert len(inputs) == self.nin
for input in inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [MyResult(self.__class__.__name__ + "_R")]
class Sigmoid(MyOp):
nin = 1
class Add(MyOp):
nin = 2
class Dot(MyOp):
nin = 2
def inputs():
x = MyResult('x')
y = MyResult('y')
z = MyResult('z')
return x, y, z
class _test_EquivTool(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
sx = Sigmoid(x)
e = Add(sx, Sigmoid(y))
g = Env([x, y, z], [e], features = [EquivTool])
assert g.equiv(sx) is sx
g.replace(sx, Dot(x, z))
assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx).owner, Dot)
if __name__ == '__main__':
unittest.main()
...@@ -8,6 +8,7 @@ from op import Op ...@@ -8,6 +8,7 @@ from op import Op
from result import is_result from result import is_result
from features import Listener, Orderings, Constraint, Tool, uniq_features from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
from utils import AbstractFunctionError
__all__ = ['InconsistencyError', __all__ = ['InconsistencyError',
'Env'] 'Env']
...@@ -104,6 +105,8 @@ class Env(graph.Graph): ...@@ -104,6 +105,8 @@ class Env(graph.Graph):
self.history = [] self.history = []
self.__import_r__(self.outputs) self.__import_r__(self.outputs)
for op in self.ops():
self.satisfy(op)
if consistency_check: if consistency_check:
self.validate() self.validate()
...@@ -161,7 +164,11 @@ class Env(graph.Graph): ...@@ -161,7 +164,11 @@ class Env(graph.Graph):
self._listeners[feature_class] = feature self._listeners[feature_class] = feature
if do_import: if do_import:
for op in self.io_toposort(): for op in self.io_toposort():
try:
# print op
feature.on_import(op) feature.on_import(op)
except AbstractFunctionError:
pass
if issubclass(feature_class, Constraint): if issubclass(feature_class, Constraint):
self._constraints[feature_class] = feature self._constraints[feature_class] = feature
if issubclass(feature_class, Orderings): if issubclass(feature_class, Orderings):
...@@ -235,6 +242,8 @@ class Env(graph.Graph): ...@@ -235,6 +242,8 @@ class Env(graph.Graph):
if not is_result(new_r): if not is_result(new_r):
raise TypeError(new_r) raise TypeError(new_r)
self.__import_r_satisfy__([new_r])
# Save where we are so we can backtrack # Save where we are so we can backtrack
if consistency_check: if consistency_check:
chk = self.checkpoint() chk = self.checkpoint()
...@@ -256,7 +265,7 @@ class Env(graph.Graph): ...@@ -256,7 +265,7 @@ class Env(graph.Graph):
# The actual replacement operation occurs here. This might raise # The actual replacement operation occurs here. This might raise
# an error. # an error.
self.__move_clients__(clients, r, new_r) self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# This function undoes the replacement. # This function undoes the replacement.
def undo(): def undo():
...@@ -344,6 +353,11 @@ class Env(graph.Graph): ...@@ -344,6 +353,11 @@ class Env(graph.Graph):
if not self._clients[r]: if not self._clients[r]:
del self._clients[r] del self._clients[r]
def __import_r_satisfy__(self, results):
for op in graph.ops(self.results(), results):
self.satisfy(op)
def __import_r__(self, results): def __import_r__(self, results):
for result in results: for result in results:
owner = result.owner owner = result.owner
...@@ -358,7 +372,6 @@ class Env(graph.Graph): ...@@ -358,7 +372,6 @@ class Env(graph.Graph):
new_ops = graph.io_toposort(self.results(), op.outputs) new_ops = graph.io_toposort(self.results(), op.outputs)
for op in new_ops: for op in new_ops:
self.satisfy(op) # add the features required by this op
self._ops.add(op) self._ops.add(op)
self._results.update(op.outputs) self._results.update(op.outputs)
...@@ -374,7 +387,10 @@ class Env(graph.Graph): ...@@ -374,7 +387,10 @@ class Env(graph.Graph):
self._results.add(input) self._results.add(input)
for listener in self._listeners.values(): for listener in self._listeners.values():
try:
listener.on_import(op) listener.on_import(op)
except AbstractFunctionError:
pass
def __prune_r__(self, results): def __prune_r__(self, results):
for result in set(results): for result in set(results):
...@@ -393,35 +409,48 @@ class Env(graph.Graph): ...@@ -393,35 +409,48 @@ class Env(graph.Graph):
self._results.difference_update(op.outputs) self._results.difference_update(op.outputs)
for listener in self._listeners.values(): for listener in self._listeners.values():
try:
listener.on_prune(op) listener.on_prune(op)
except AbstractFunctionError:
pass
for i, input in enumerate(op.inputs): for i, input in enumerate(op.inputs):
self.__remove_clients__(input, [(op, i)]) self.__remove_clients__(input, [(op, i)])
self.__prune_r__(op.inputs) self.__prune_r__(op.inputs)
def __move_clients__(self, clients, r, new_r): def __move_clients__(self, clients, r, new_r):
# We import the new result in the fold
self.__import_r__([new_r])
try: try:
# Try replacing the inputs # Try replacing the inputs
for op, i in clients: for op, i in clients:
op.set_input(i, new_r, False) op.set_input(i, new_r)
except GofTypeError, PropagationError: except:
# Oops! # Oops!
for op, i in clients: for op, i in clients:
op.set_input(i, r, False) op.set_input(i, r)
self.__prune_r__([new_r])
raise raise
self.__remove_clients__(r, clients) self.__remove_clients__(r, clients)
self.__add_clients__(new_r, clients) self.__add_clients__(new_r, clients)
# We import the new result in the fold # # We import the new result in the fold
self.__import_r__([new_r]) # # why was this line AFTER the set_inputs???
# # if we do it here then satisfy in import fucks up...
# self.__import_r__([new_r])
for listener in self._listeners.values(): for listener in self._listeners.values():
try:
listener.on_rewire(clients, r, new_r) listener.on_rewire(clients, r, new_r)
except AbstractFunctionError:
pass
# We try to get rid of the old one # We try to get rid of the old one
self.__prune_r__([r]) self.__prune_r__([r])
def __str__(self): def __str__(self):
return graph.as_string(self.inputs, self.outputs) return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
# from copy import copy from features import Listener, Constraint, Orderings
# from op import Op from utils import AbstractFunctionError
# from lib import DummyOp
# from features import Listener, Constraint, Orderings
# from env import InconsistencyError
# from utils import ClsInit
# import graph
from copy import copy
from env import InconsistencyError
__all__ = ['Destroyer', 'Viewer']
__all__ = ['Destroyer',
'Viewer',
'DestroyHandler',
]
class DestroyHandler(Listener, Constraint, Orderings):
def __init__(self, env):
self.parent = {}
self.children = {}
self.destroyers = {}
self.paths = {}
self.dups = set()
self.cycles = set()
self.illegal = set()
self.env = env
self.seen = set()
for input in env.inputs:
self.children[input] = set()
def __path__(self, r):
path = self.paths.get(r, None)
if path:
return path
rval = [r]
r = self.parent.get(r, None) ### ???
while r:
rval.append(r)
r = self.parent.get(r, None)
rval.reverse()
for i, x in enumerate(rval):
self.paths[x] = rval[0:i+1]
return rval
def __views__(self, r):
children = self.children[r]
if not children:
return set([r])
else:
rval = set([r])
for child in children:
rval.update(self.__views__(child))
return rval
def __users__(self, r):
views = self.__views__(r)
rval = set()
for view in views:
for op, i in self.env.clients(view):
if op in self.seen:
rval.update(op.outputs)
return rval
class Return(DummyOp): def __pre__(self, op):
""" rval = set()
Dummy op which represents the action of returning its input if op is None:
value to an end user. It "destroys" its input to prevent any return rval
other Op to overwrite it. keep_going = False
""" for input in op.inputs:
def destroy_map(self): return {self.out:[self.inputs[0]]} foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
rval.update(op.inputs)
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
pre = self.__pre__(r.owner)
for r2 in pre:
self.__detect_cycles_helper__(r2, seq + [r])
def mark_outputs_as_destroyed(outputs): def __detect_cycles__(self, start, just_remove=False):
return [Return(output).out for output in outputs] users = self.__users__(start)
users.add(start)
for user in users:
for cycle in copy(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
return
for user in users:
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
try: vmap = op.view_map()
except AttributeError, AbstractFunctionError: vmap = {}
try: dmap = op.destroy_map()
except AttributeError, AbstractFunctionError: dmap = {}
return vmap, dmap
def on_import(self, op):
self.seen.add(op)
view_map, destroy_map = self.get_maps(op)
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
####### self.__add_destroyer__(path + [op])
elif views:
if len(views) > 1:
raise Exception("Output is a view of too many inputs.")
self.parent[output] = views[0]
for input in views:
self.children[input].add(output)
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
def on_prune(self, op):
view_map, destroy_map = self.get_maps(op)
if destroy_map:
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
for destroyer in destroyers:
path = destroyer.get(op, [])
if path:
self.__remove_destroyer__(path)
if view_map:
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
for output in op.outputs:
try:
del self.paths[output]
except:
pass
self.__detect_cycles__(output, True)
for i, output in enumerate(op.outputs):
try:
self.parent[output]
del self.parent[output]
except:
pass
del self.children[output]
self.seen.remove(op)
def __add_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
if len(destroyers) > 1:
self.dups.add(foundation)
if getattr(foundation, 'indestructible', False):
self.illegal.add(foundation)
def __remove_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers[foundation]
del destroyers[op]
if not destroyers:
if foundation in self.illegal:
self.illegal.remove(foundation)
del self.destroyers[foundation]
elif len(destroyers) == 1 and foundation in self.dups:
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
prev = set()
for op, i in clients:
prev.update(op.outputs)
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
if r_2 in inputs:
assert self.parent.get(output, None) == r_1
self.parent[output] = r_2
self.children[r_1].remove(output)
self.children[r_2].add(output)
for view in self.__views__(r_1):
try:
del self.paths[view]
except:
pass
for view in self.__views__(r_2):
try:
del self.paths[view]
except:
pass
self.__detect_cycles__(r_1)
self.__detect_cycles__(r_2)
def validate(self):
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
raise InconsistencyError("There are cycles: %s" % self.cycles)
elif self.illegal:
raise InconsistencyError("Attempting to destroy indestructible results: %s" % self.illegal)
else:
return True
def orderings(self):
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
return ords
class Destroyer:
def destroyed_inputs(self):
raise AbstractFunctionError()
def destroy_map(self):
# compatibility
return {self.out: self.destroyed_inputs()}
__env_require__ = DestroyHandler
class Viewer:
def view_map(self):
raise AbstractFunctionError()
def view_roots(self, output):
def helper(r):
"""Return the leaves of a search through consecutive view_map()s"""
owner = r.owner
if owner is not None:
try:
view_map = owner.view_map()
except AttributeError, AbstractFunctionError:
return []
if r in view_map:
answer = []
for r2 in view_map[r]:
answer.extend(helper(r2))
return answer
else:
return [r]
else:
return [r]
return helper(output)
# from copy import copy
# from op import Op
# import result
# import graph
import utils import utils
# from random import shuffle
__all__ = ['Feature', __all__ = ['Feature',
'Listener', 'Listener',
...@@ -15,10 +8,6 @@ __all__ = ['Feature', ...@@ -15,10 +8,6 @@ __all__ = ['Feature',
'Orderings', 'Orderings',
'Tool', 'Tool',
'uniq_features', 'uniq_features',
# 'EquivTool',
# 'InstanceFinder',
# 'PrintListener',
# 'ChangeListener',
] ]
...@@ -129,133 +118,3 @@ def uniq_features(_features, *_rest): ...@@ -129,133 +118,3 @@ def uniq_features(_features, *_rest):
return res return res
# MOVE TO LIB
# class EquivTool(Listener, Tool, dict):
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def all_bases(self, cls):
# return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
# # return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class PrintListener(Listener):
# def __init__(self, env, active = True):
# self.env = env
# self.active = active
# if active:
# print "-- initializing"
# def on_import(self, op):
# if self.active:
# print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
# def on_prune(self, op):
# if self.active:
# print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
# def on_rewire(self, clients, r, new_r):
# if self.active:
# if r.owner is None:
# rg = id(r) #r.name
# else:
# rg = graph.as_string(self.env.inputs, r.owner.outputs)
# if new_r.owner is None:
# new_rg = id(new_r) #new_r.name
# else:
# new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
# print "-- moving from %s to %s" % (rg, new_rg)
# class ChangeListener(Listener):
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
...@@ -103,251 +103,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -103,251 +103,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else: else:
return True return True
class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env):
self.parent = {}
self.children = {}
self.destroyers = {}
self.paths = {}
self.dups = set()
self.cycles = set()
self.env = env
for input in env.inputs:
# self.parent[input] = None
self.children[input] = set()
def __path__(self, r):
path = self.paths.get(r, None)
if path:
return path
rval = [r]
r = self.parent.get(r, None) ### ???
while r:
rval.append(r)
r = self.parent.get(r, None)
rval.reverse()
for i, x in enumerate(rval):
self.paths[x] = rval[0:i+1]
return rval
def __views__(self, r):
children = self.children[r]
if not children:
return set([r])
else:
rval = set([r])
for child in children:
rval.update(self.__views__(child))
return rval
def __users__(self, r):
views = self.__views__(r)
rval = set()
for view in views:
for op, i in self.env.clients(view):
rval.update(op.outputs)
return rval
def __pre__(self, op):
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
# if not keep_going:
# return set()
rval.update(op.inputs)
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
# print "!! ", r, seq
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
pre = self.__pre__(r.owner)
for r2 in pre:
self.__detect_cycles_helper__(r2, seq + [r])
def __detect_cycles__(self, start, just_remove=False):
# print "!!! ", start
users = self.__users__(start)
users.add(start)
for user in users:
for cycle in copy(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
return
for user in users:
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
return op.view_map(), op.destroy_map()
def on_import(self, op):
view_map, destroy_map = self.get_maps(op)
# for input in op.inputs:
# self.parent.setdefault(input, None)
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
# self.parent[output] = None
if is_result(destroyed):
destroyed = [destroyed]
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
elif views:
if is_result(views):
views = [views]
if len(views) > 1: #views was inputs before?
raise Exception("Output is a view of too many inputs.")
self.parent[output] = views[0]
for input in views:
self.children[input].add(output)
# else:
# self.parent[output] = None
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
# if destroy_map:
# print "op: ", op
# print "ord: ", [str(x) for x in self.orderings()[op]]
# print
def on_prune(self, op):
view_map, destroy_map = self.get_maps(op)
if destroy_map:
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
for destroyer in destroyers:
path = destroyer.get(op, [])
if path:
self.__remove_destroyer__(path)
if view_map:
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
for output in op.outputs:
try:
del self.paths[output]
except:
pass
self.__detect_cycles__(output, True)
for i, output in enumerate(op.outputs):
try:
del self.parent[output]
except:
pass
del self.children[output]
def __add_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
if len(destroyers) > 1:
self.dups.add(foundation)
def __remove_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers[foundation]
del destroyers[op]
if not destroyers:
del self.destroyers[foundation]
elif len(destroyers) == 1 and foundation in self.dups:
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
prev = set()
for op, i in clients:
prev.update(op.outputs)
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
if r_2 in inputs:
assert self.parent.get(output, None) == r_1
self.parent[output] = r_2
self.children[r_1].remove(output)
self.children[r_2].add(output)
for view in self.__views__(r_1):
try:
del self.paths[view]
except:
pass
for view in self.__views__(r_2):
try:
del self.paths[view]
except:
pass
self.__detect_cycles__(r_1)
self.__detect_cycles__(r_2)
def validate(self):
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
raise InconsistencyError("There are cycles: %s" % self.cycles)
else:
return True
def orderings(self):
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
return ords
class NewPythonOp(Op): class NewPythonOp(Op):
......
差异被折叠。
...@@ -3,7 +3,7 @@ from op import Op ...@@ -3,7 +3,7 @@ from op import Op
from env import InconsistencyError from env import InconsistencyError
import utils import utils
import unify import unify
import features import toolbox
import ext import ext
...@@ -51,21 +51,11 @@ class LocalOptimizer(Optimizer): ...@@ -51,21 +51,11 @@ class LocalOptimizer(Optimizer):
if env.has_op(op): if env.has_op(op):
self.apply_on_op(env, op) self.apply_on_op(env, op)
# no_change_listener = graph.changed is None
# while(True):
# exprs = self.candidates(graph)
# graph.changed(False)
# for expr in exprs:
# self.apply_on_op(graph, expr)
# if no_change_listener or graph.changed:
# break
# else:
# break
class OpSpecificOptimizer(LocalOptimizer): class OpSpecificOptimizer(LocalOptimizer):
__env_require__ = features.InstanceFinder __env_require__ = toolbox.InstanceFinder
opclass = Op opclass = Op
...@@ -77,10 +67,10 @@ class OpSpecificOptimizer(LocalOptimizer): ...@@ -77,10 +67,10 @@ class OpSpecificOptimizer(LocalOptimizer):
class OpSubOptimizer(Optimizer): class OpSubOptimizer(Optimizer):
__env_require__ = features.InstanceFinder __env_require__ = toolbox.InstanceFinder
def __init__(self, op1, op2): def __init__(self, op1, op2):
if not op1.has_default_output: if not op1._default_output_idx >= 0:
raise TypeError("OpSubOptimizer must be used with Op instances that have a default output.") raise TypeError("OpSubOptimizer must be used with Op instances that have a default output.")
# note: op2 must have the same input signature as op1 # note: op2 must have the same input signature as op1
self.op1 = op1 self.op1 = op1
...@@ -97,14 +87,14 @@ class OpSubOptimizer(Optimizer): ...@@ -97,14 +87,14 @@ class OpSubOptimizer(Optimizer):
r = r.out r = r.out
env.replace(op.out, r) env.replace(op.out, r)
except InconsistencyError, e: except InconsistencyError, e:
print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug # print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
pass pass
class OpRemover(Optimizer): class OpRemover(Optimizer):
__env_require__ = features.InstanceFinder __env_require__ = toolbox.InstanceFinder
def __init__(self, opclass): def __init__(self, opclass):
self.opclass = opclass self.opclass = opclass
...@@ -118,7 +108,7 @@ class OpRemover(Optimizer): ...@@ -118,7 +108,7 @@ class OpRemover(Optimizer):
for input, output in zip(op.inputs, op.outputs): for input, output in zip(op.inputs, op.outputs):
env.replace(output, input) env.replace(output, input)
except InconsistencyError, e: except InconsistencyError, e:
print "Warning: OpRemover failed to remove %s: %s" % (op, str(e)) # warning is for debug # print "Warning: OpRemover failed to remove %s: %s" % (op, str(e)) # warning is for debug
pass pass
......
...@@ -11,8 +11,8 @@ from python25 import all ...@@ -11,8 +11,8 @@ from python25 import all
__all__ = ['is_result', __all__ = ['is_result',
'ResultBase', 'ResultBase',
'BrokenLink', # 'BrokenLink',
'BrokenLinkError', # 'BrokenLinkError',
'StateError', 'StateError',
'Empty', 'Empty',
'Allocated', 'Allocated',
...@@ -20,14 +20,14 @@ __all__ = ['is_result', ...@@ -20,14 +20,14 @@ __all__ = ['is_result',
] ]
class BrokenLink: # class BrokenLink:
"""The owner of a Result that was replaced by another Result""" # """The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role'] # __slots__ = ['old_role']
def __init__(self, role): self.old_role = role # def __init__(self, role): self.old_role = role
def __nonzero__(self): return False # def __nonzero__(self): return False
class BrokenLinkError(Exception): # class BrokenLinkError(Exception):
"""The owner is a BrokenLink""" # """The owner is a BrokenLink"""
class StateError(Exception): class StateError(Exception):
"""The state of the Result is a problem""" """The state of the Result is a problem"""
...@@ -52,7 +52,7 @@ class ResultBase(object): ...@@ -52,7 +52,7 @@ class ResultBase(object):
"""Base class for storing Op inputs and outputs """Base class for storing Op inputs and outputs
Attributes: Attributes:
_role - None or (owner, index) or BrokenLink _role - None or (owner, index) #or BrokenLink
_data - anything _data - anything
constant - Boolean constant - Boolean
state - one of (Empty, Allocated, Computed) state - one of (Empty, Allocated, Computed)
...@@ -63,7 +63,7 @@ class ResultBase(object): ...@@ -63,7 +63,7 @@ class ResultBase(object):
owner - (ro) owner - (ro)
index - (ro) index - (ro)
data - (rw) : calls data_filter when setting data - (rw) : calls data_filter when setting
replaced - (rw) : True iff _role is BrokenLink # replaced - (rw) : True iff _role is BrokenLink
Methods: Methods:
alloc() - create storage in data, suitable for use by C ops. alloc() - create storage in data, suitable for use by C ops.
...@@ -74,15 +74,15 @@ class ResultBase(object): ...@@ -74,15 +74,15 @@ class ResultBase(object):
data_alloc data_alloc
Notes (from previous implementation): # Notes (from previous implementation):
A Result instance should be immutable: indeed, if some aspect of a # A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become # Result is changed, operations that use it might suddenly become
invalid. Instead, a new Result instance should be instanciated # invalid. Instead, a new Result instance should be instanciated
with the correct properties and the invalidate method should be # with the correct properties and the invalidate method should be
called on the Result which is replaced (this will make its owner a # called on the Result which is replaced (this will make its owner a
BrokenLink instance, which behaves like False in conditional # BrokenLink instance, which behaves like False in conditional
expressions). # expressions).
""" """
...@@ -124,7 +124,7 @@ class ResultBase(object): ...@@ -124,7 +124,7 @@ class ResultBase(object):
def __get_owner(self): def __get_owner(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise BrokenLinkError() # if self.replaced: raise BrokenLinkError()
return self._role[0] return self._role[0]
owner = property(__get_owner, owner = property(__get_owner,
...@@ -136,7 +136,7 @@ class ResultBase(object): ...@@ -136,7 +136,7 @@ class ResultBase(object):
def __get_index(self): def __get_index(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise BrokenLinkError() # if self.replaced: raise BrokenLinkError()
return self._role[1] return self._role[1]
index = property(__get_index, index = property(__get_index,
...@@ -151,8 +151,8 @@ class ResultBase(object): ...@@ -151,8 +151,8 @@ class ResultBase(object):
return self._data[0] return self._data[0]
def __set_data(self, data): def __set_data(self, data):
if self.replaced: # if self.replaced:
raise BrokenLinkError() # raise BrokenLinkError()
if data is self._data[0]: if data is self._data[0]:
return return
if self.constant: if self.constant:
...@@ -212,17 +212,17 @@ class ResultBase(object): ...@@ -212,17 +212,17 @@ class ResultBase(object):
# replaced # replaced
# #
def __get_replaced(self): # def __get_replaced(self):
return isinstance(self._role, BrokenLink) # return isinstance(self._role, BrokenLink)
def __set_replaced(self, replace): # def __set_replaced(self, replace):
if replace == self.replaced: return # if replace == self.replaced: return
if replace: # if replace:
self._role = BrokenLink(self._role) # self._role = BrokenLink(self._role)
else: # else:
self._role = self._role.old_role # self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") # replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# #
...@@ -307,6 +307,24 @@ class ResultBase(object): ...@@ -307,6 +307,24 @@ class ResultBase(object):
return self.name or "<?>" return self.name or "<?>"
#
# same properties
#
# def __eq__(self, other):
# if self.state is not Computed:
# raise StateError("Can only compare computed results for equality.")
# if isinstance(other, Result):
# if other.state is not Computed:
# raise StateError("Can only compare computed results for equality.")
# return self.data == other.data
# else:
# return self.data == other
def same_properties(self, other):
raise AbstractFunction()
################# #################
# NumpyR Compatibility # NumpyR Compatibility
# #
......
from features import Listener, Tool
from random import shuffle
import utils
__all__ = ['EquivTool',
'InstanceFinder',
'PrintListener',
'ChangeListener',
]
class EquivTool(Listener, Tool, dict):
def on_rewire(self, clients, r, new_r):
repl = self(new_r)
if repl is r:
self.ungroup(r, new_r)
elif repl is not new_r:
raise Exception("Improper use of EquivTool!")
else:
self.group(new_r, r)
def publish(self):
self.env.equiv = self
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
keys = [key for key in keys if key is not main]
if self.has_key(main):
raise Exception("Only group results that have not been grouped before.")
for key in keys:
if self.has_key(key):
raise Exception("Only group results that have not been grouped before.")
if key is main:
continue
self.setdefault(key, main)
def ungroup(self, main, *keys):
"Undoes group(main, *keys)"
keys = [key for key in keys if key is not main]
for key in keys:
if self[key] is main:
del self[key]
def __call__(self, key):
"Returns the currently active replacement for the given key."
next = self.get(key, None)
while next:
key = next
next = self.get(next, None)
return key
class InstanceFinder(Listener, Tool, dict):
def __init__(self, env):
self.env = env
def all_bases(self, cls):
return utils.all_bases(cls, lambda cls: cls is not object)
def on_import(self, op):
for base in self.all_bases(op.__class__):
self.setdefault(base, set()).add(op)
def on_prune(self, op):
for base in self.all_bases(op.__class__):
self[base].remove(op)
if not self[base]:
del self[base]
def __query__(self, cls):
all = [x for x in self.get(cls, [])]
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, cls):
return self.__query__(cls)
def publish(self):
self.env.get_instances_of = self.query
class PrintListener(Listener):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % op
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % op
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is not None: r = r.owner
if new_r.owner is not None: new_r = new_r.owner
print "-- moving from %s to %s" % (r, new_r)
class ChangeListener(Listener):
def __init__(self, env):
self.change = False
def on_import(self, op):
self.change = True
def on_prune(self, op):
self.change = True
def on_rewire(self, clients, r, new_r):
self.change = True
def __call__(self, value = "get"):
if value == "get":
return self.change
else:
self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论