提交 a5129f65 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

redone result, op, tests for result, op, graph

上级 4a6845b1
import unittest
from graph import *
from op import Op
from result import ResultBase, BrokenLinkError
class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self):
return str(self.thingy)
def __repr__(self):
return str(self.thingy)
class MyOp(Op):
def validate_update(self):
for input in self.inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_inputs(unittest.TestCase):
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2])
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5])
class _test_orphans(unittest.TestCase):
def test_0(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
class _test_as_string(unittest.TestCase):
leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings))
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_2(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_3(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"]
class _test_clone(unittest.TestCase):
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0]
assert op2 is not new[0].owner
assert new[0].owner.inputs[1] is r5
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0]
def test_2(self):
"Checks that manipulating a cloned graph leaves the original unchanged."
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs)
new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"]
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"]
if __name__ == '__main__':
unittest.main()
import unittest
from copy import copy
from op import *
from result import ResultBase, BrokenLinkError
class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self):
return str(self.thingy)
def __repr__(self):
return str(self.thingy)
class MyOp(Op):
def validate_update(self):
for input in self.inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase):
# Sanity tests
def test_sanity_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert op.inputs == [r1, r2] # Are the inputs what I provided?
assert op.outputs == [MyResult(3)] # Are the outputs what I expect?
# validate_update
def test_validate_update(self):
try:
MyOp(ResultBase(), MyResult(1)) # MyOp requires MyResult instances
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception")
# Setting inputs and outputs
def test_set_inputs(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.outputs[0]
op.inputs = MyResult(4), MyResult(5)
assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
def test_set_bad_inputs(self):
op = MyOp(MyResult(1), MyResult(2))
try:
op.inputs = MyResult(4), ResultBase()
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception")
def test_set_outputs(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # here we only make one output
try:
op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
except TypeError, e:
assert str(e) == "The new outputs must be exactly as many as the previous outputs."
else:
raise Exception("Expected an exception")
# Tests about broken links
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op.inputs = MyResult(3), MyResult(4)
assert r3 not in op.outputs
assert r3.replaced
def test_cannot_copy_when_input_is_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
copy(op2)
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_get_input_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
op2.get_input(0)
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_get_inputs_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
op2.get_inputs()
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_repair_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3, MyResult(10))
op.inputs = MyResult(3), MyResult(4)
op2.repair()
assert op2.outputs == [MyResult(17)]
# Tests about string representation
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert str(op) == "MyOp(1, 2)"
if __name__ == '__main__':
unittest.main()
import unittest
from result import *
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
...@@ -6,53 +6,21 @@ from utils import ClsInit ...@@ -6,53 +6,21 @@ from utils import ClsInit
from err import GofError, GofTypeError, PropagationError from err import GofError, GofTypeError, PropagationError
from op import Op from op import Op
from result import is_result from result import is_result
from features import Listener, Orderings, Constraint, Tool from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
__all__ = ['InconsistencyError', __all__ = ['InconsistencyError',
'Env'] 'Env']
# class AliasDict(dict): class InconsistencyError(Exception):
# "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
#TODO: why is this not in err.py? -James
class InconsistencyError(GofError):
""" """
This exception is raised by Env whenever one of the listeners marks This exception is raised by Env whenever one of the listeners marks
the graph as inconsistent. the graph as inconsistent.
""" """
pass pass
def require_set(cls): def require_set(cls):
"""Return the set of objects named in a __env_require__ field in a base class""" """Return the set of objects named in a __env_require__ field in a base class"""
r = set() r = set()
...@@ -70,22 +38,22 @@ def require_set(cls): ...@@ -70,22 +38,22 @@ def require_set(cls):
r.add(req) r.add(req)
return r return r
class Env(graph.Graph): class Env(graph.Graph):
""" """
An Env represents a subgraph bound by a set of input results and a set of output An Env represents a subgraph bound by a set of input results and a
results. An op is in the subgraph iff it depends on the value of some of the Env's set of output results. An op is in the subgraph iff it depends on
inputs _and_ some of the Env's outputs depend on it. A result is in the subgraph the value of some of the Env's inputs _and_ some of the Env's
iff it is an input or an output of an op that is in the subgraph. outputs depend on it. A result is in the subgraph iff it is an
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a result in the The Env supports the replace operation which allows to replace a
subgraph by another, e.g. replace (x + x).out by (2 * x).out. This is the basis result in the subgraph by another, e.g. replace (x + x).out by (2
for optimization in omega. * x).out. This is the basis for optimization in omega.
An Env can have listeners, which are instances of EnvListener. Each listener is An Env's functionality can be extended with features, which must
informed of any op entering or leaving the subgraph (which happens at construction be subclasses of L{Listener}, L{Constraint}, L{Orderings} or
time and whenever there is a replacement). In addition to that, each listener can L{Tool}.
implement the 'consistent' and 'ordering' methods (see EnvListener) in order to
restrict how ops in the subgraph can be related.
Regarding inputs and orphans: Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are both In the context of a computation graph, the inputs and orphans are both
...@@ -93,7 +61,6 @@ class Env(graph.Graph): ...@@ -93,7 +61,6 @@ class Env(graph.Graph):
named as inputs will be assumed to contain fresh. In other words, the named as inputs will be assumed to contain fresh. In other words, the
backward search from outputs will stop at any node that has been explicitly backward search from outputs will stop at any node that has been explicitly
named as an input. named as an input.
""" """
### Special ### ### Special ###
...@@ -102,11 +69,6 @@ class Env(graph.Graph): ...@@ -102,11 +69,6 @@ class Env(graph.Graph):
""" """
Create an Env which operates on the subgraph bound by the inputs and outputs Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated. sets. If consistency_check is False, an illegal graph will be tolerated.
Features are class types derived from things in the tools file. These
can be listeners, constraints, orderings, etc. Features add much
(most?) functionality to an Env.
""" """
self._features = {} self._features = {}
...@@ -114,31 +76,13 @@ class Env(graph.Graph): ...@@ -114,31 +76,13 @@ class Env(graph.Graph):
self._constraints = {} self._constraints = {}
self._orderings = {} self._orderings = {}
self._tools = {} self._tools = {}
# self._preprocessors = set()
# for feature in features:
# if issubclass(feature, tools.Preprocessor):
# preprocessor = feature()
# self._preprocessors.add(preprocessor)
# inputs, outputs = preprocessor.transform(inputs, outputs)
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = set(inputs) self.inputs = set(inputs)
self.outputs = set(outputs) self.outputs = set(outputs)
for feature_class in utils.uniq_features(features): for feature_class in uniq_features(features):
self.add_feature(feature_class, False) self.add_feature(feature_class, False)
# feature = feature_class(self)
# if isinstance(feature, tools.Listener):
# self._listeners.add(feature)
# if isinstance(feature, tools.Constraint):
# self._constraints.add(feature)
# if isinstance(feature, tools.Orderings):
# self._orderings.add(feature)
# if isinstance(feature, tools.Tool):
# self._tools.add(feature)
# feature.publish()
# All ops in the subgraph defined by inputs and outputs are cached in _ops # All ops in the subgraph defined by inputs and outputs are cached in _ops
self._ops = set() self._ops = set()
...@@ -234,20 +178,6 @@ class Env(graph.Graph): ...@@ -234,20 +178,6 @@ class Env(graph.Graph):
except KeyError: except KeyError:
pass pass
# for i, feature in enumerate(self._features):
# if isinstance(feature, feature_class): # exact class or subclass, nothing to do
# return
# elif issubclass(feature_class, feature.__class__): # superclass, we replace it
# new_feature = feature_class(self)
# self._features[i] = new_feature
# break
# else:
# new_feature = feature_class(self)
# self._features.append(new_feature)
# if isinstance(new_feature, tools.Listener):
# for op in self.io_toposort():
# new_feature.on_import(op)
def get_feature(self, feature_class): def get_feature(self, feature_class):
try: try:
return self._features[feature_class] return self._features[feature_class]
...@@ -325,7 +255,7 @@ class Env(graph.Graph): ...@@ -325,7 +255,7 @@ class Env(graph.Graph):
self.outputs.add(new_r) self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise # The actual replacement operation occurs here. This might raise
# a GofTypeError # an error.
self.__move_clients__(clients, r, new_r) self.__move_clients__(clients, r, new_r)
# This function undoes the replacement. # This function undoes the replacement.
......
from copy import copy # from copy import copy
from op import Op # from op import Op
from lib import DummyOp # from lib import DummyOp
from features import Listener, Constraint, Orderings # from features import Listener, Constraint, Orderings
from env import InconsistencyError # from env import InconsistencyError
from utils import ClsInit # from utils import ClsInit
import graph # import graph
#TODO: move mark_outputs_as_destroyed to the place that uses this function __all__ = ['Destroyer', 'Viewer']
#TODO: move Return to where it is used.
__all__ = ['IONames', 'mark_outputs_as_destroyed']
class IONames:
"""
Requires assigning a name to each of this Op's inputs and outputs.
"""
__metaclass__ = ClsInit
input_names = ()
output_names = ()
@staticmethod
def __clsinit__(cls, name, bases, dct):
for names in ['input_names', 'output_names']:
if names in dct:
x = getattr(cls, names)
if isinstance(x, str):
x = [x,]
setattr(cls, names, x)
if isinstance(x, (list, tuple)):
x = [a for a in x]
setattr(cls, names, x)
for i, varname in enumerate(x):
if not isinstance(varname, str) or hasattr(cls, varname) or varname in ['inputs', 'outputs']:
raise TypeError("In %s: '%s' is not a valid input or output name" % (cls.__name__, varname))
# Set an attribute for the variable so we can do op.x to return the input or output named "x".
setattr(cls, varname,
property(lambda op, type=names.replace('_name', ''), index=i:
getattr(op, type)[index]))
else:
print 'ERROR: Class variable %s::%s is neither list, tuple, or string' % (name, names)
raise TypeError, str(names)
else:
setattr(cls, names, ())
# def __init__(self, inputs, outputs, use_self_setters = False):
# assert len(inputs) == len(self.input_names)
# assert len(outputs) == len(self.output_names)
# Op.__init__(self, inputs, outputs, use_self_setters)
def __validate__(self):
assert len(self.inputs) == len(self.input_names)
assert len(self.outputs) == len(self.output_names)
@classmethod
def n_inputs(cls):
return len(cls.input_names)
@classmethod
def n_outputs(cls):
return len(cls.output_names)
def get_by_name(self, name):
"""
Returns the input or output which corresponds to the given name.
"""
if name in self.input_names:
return self.input_names[self.input_names.index(name)]
elif name in self.output_names:
return self.output_names[self.output_names.index(name)]
else:
raise AttributeError("No such input or output name for %s: %s" % (self.__class__.__name__, name))
......
from copy import copy # from copy import copy
from op import Op # from op import Op
import result # import result
import graph # import graph
import utils import utils
from random import shuffle # from random import shuffle
__all__ = ['Feature', __all__ = ['Feature',
...@@ -13,267 +14,248 @@ __all__ = ['Feature', ...@@ -13,267 +14,248 @@ __all__ = ['Feature',
'Constraint', 'Constraint',
'Orderings', 'Orderings',
'Tool', 'Tool',
'EquivTool', 'uniq_features',
'InstanceFinder', # 'EquivTool',
'PrintListener', # 'InstanceFinder',
'ChangeListener', # 'PrintListener',
# 'ChangeListener',
] ]
class Feature(object): class Feature(object):
def __init__(self, env): def __init__(self, env):
"""
Initializes the Feature's env field to the parameter
provided.
"""
self.env = env self.env = env
class Listener(Feature): class Listener(Feature):
"""
When registered by an env, each listener is informed of any op
entering or leaving the subgraph (which happens at construction
time and whenever there is a replacement).
"""
def on_import(self, op): def on_import(self, op):
pass """
This method is called by the env whenever a new op is
added to the graph.
"""
raise utils.AbstractFunctionError()
def on_prune(self, op): def on_prune(self, op):
pass """
This method is called by the env whenever an op is
removed from the graph.
"""
raise utils.AbstractFunctionError()
def on_rewire(self, clients, r, new_r): def on_rewire(self, clients, r, new_r):
pass """
clients -> (op, i) pairs such that op.inputs[i] is new_r
but used to be r
r -> the old result that was used by the ops in clients
new_r -> the new result that is now used by the ops in clients
Note that the change from r to new_r is done before this
method is called.
"""
raise utils.AbstractFunctionError()
class Constraint(Feature): class Constraint(Feature):
"""
When registered by an env, a Constraint can restrict the ops that
can be in the subgraph or restrict the ways ops interact with each
other.
"""
def validate(self): def validate(self):
return True """
Raises an L{InconsistencyError} if the env is currently
invalid from the perspective of this object.
"""
raise utils.AbstractFunctionError()
class Orderings(Feature): class Orderings(Feature):
"""
When registered by an env, an Orderings object can provide supplemental
ordering constraints to the subgraph's topological sort.
"""
def orderings(self): def orderings(self):
return {} """
Returns {op: set(ops that must be evaluated before this op), ...}
This is called by env.orderings() and used in env.toposort() but
not in env.io_toposort().
"""
raise utils.AbstractFunctionError()
class Tool(Feature): class Tool(Feature):
"""
A Tool can extend the functionality of an env so that, for example,
optimizations can have access to efficient ways to search the graph.
"""
def publish(self): def publish(self):
pass """
This is only called once by the env, when the Tool is added.
Adds methods to env.
# class Preprocessor(Feature): """
raise utils.AbstractFunctionError()
# def transform(self, inputs, outputs):
# return inputs, outputs
# def __call__(self, inputs, outputs): def uniq_features(_features, *_rest):
# return self.transform(inputs, outputs) """Return a list such that no element is a subclass of another"""
# used in Env.__init__ to
features = [x for x in _features]
# class Optimization(object): for other in _rest:
features += [x for x in other]
# def require(self): res = []
# return [] while features:
feature = features.pop()
# def apply(self, env): for feature2 in features:
# pass if issubclass(feature2, feature):
break
# def __call__(self, env):
# return self.apply(env)
# Optimization
# * require <tool_class>*
# * apply
# Prog
# * __init__
# * inputs
# * outputs
# * listeners, constraints, orderings
# * dispatched by isinstance Listener, etc.
# * {tool_class: preferred_implementation, ...}
class EquivTool(Listener, Tool, dict):
def on_rewire(self, clients, r, new_r):
repl = self(new_r)
if repl is r:
self.ungroup(r, new_r)
elif repl is not new_r:
raise Exception("Improper use of EquivTool!")
else: else:
self.group(new_r, r) res.append(feature)
return res
def publish(self):
self.env.equiv = self
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
keys = [key for key in keys if key is not main]
if self.has_key(main):
raise Exception("Only group results that have not been grouped before.")
for key in keys:
if self.has_key(key):
raise Exception("Only group results that have not been grouped before.")
if key is main:
continue
self.setdefault(key, main)
def ungroup(self, main, *keys):
"Undoes group(main, *keys)"
keys = [key for key in keys if key is not main]
for key in keys:
if self[key] is main:
del self[key]
def __call__(self, key):
"Returns the currently active replacement for the given key."
next = self.get(key, None)
while next:
key = next
next = self.get(next, None)
return key
class InstanceFinder(Listener, Tool, dict):
def __init__(self, env):
self.env = env
def all_bases(self, cls):
return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
# return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
def on_import(self, op):
for base in self.all_bases(op.__class__):
self.setdefault(base, set()).add(op)
def on_prune(self, op):
for base in self.all_bases(op.__class__):
self[base].remove(op)
if not self[base]:
del self[base]
def __query__(self, cls):
all = [x for x in self.get(cls, [])]
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, cls):
return self.__query__(cls)
def publish(self):
self.env.get_instances_of = self.query
class PrintListener(Listener):
def __init__(self, env, active = True): # MOVE TO LIB
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is None:
rg = id(r) #r.name
else:
rg = graph.as_string(self.env.inputs, r.owner.outputs)
if new_r.owner is None:
new_rg = id(new_r) #new_r.name
else:
new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
print "-- moving from %s to %s" % (rg, new_rg)
# class EquivTool(Listener, Tool, dict):
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
class ChangeListener(Listener): # def publish(self):
# self.env.equiv = self
def __init__(self, env):
self.change = False # def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
def on_import(self, op): # keys = [key for key in keys if key is not main]
self.change = True # if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
def on_prune(self, op): # for key in keys:
self.change = True # if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
def on_rewire(self, clients, r, new_r): # if key is main:
self.change = True # continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
def __call__(self, value = "get"): # def all_bases(self, cls):
if value == "get": # return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
return self.change # # return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
else:
self.change = value
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class SuperFinder(Listener, Tool, dict): # class PrintListener(Listener):
# def __init__(self, env, props): # def __init__(self, env, active = True):
# self.env = env # self.env = env
# self.props = props # self.active = active
# if active:
# print "-- initializing"
# def on_import(self, op): # def on_import(self, op):
# for prop, value in self.props(op).items(): # if self.active:
# self.setdefault(prop, {}).setdefault(value, set()).add(op) # print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
# def on_prune(self, op): # def on_prune(self, op):
# for prop, value in self.props(op).items(): # if self.active:
# self[prop][value].remove(op) # print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
# if len(self[prop][value]) == 0:
# del self[prop][value]
# if len(self[prop]) == 0:
# del self[prop]
# def __query__(self, order, template):
# all = []
# for prop, value in template.items():
# all += [x for x in self.get(prop, {}).get(value, set())]
# # If not None, the order option requires the order listener to be included in the env under the name 'order'
# if order == 'o->i':
# all.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
# elif order == 'i->o':
# all.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, **template): # def on_rewire(self, clients, r, new_r):
# return self.__query__(None, template) # if self.active:
# if r.owner is None:
# rg = id(r) #r.name
# else:
# rg = graph.as_string(self.env.inputs, r.owner.outputs)
# if new_r.owner is None:
# new_rg = id(new_r) #new_r.name
# else:
# new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
# print "-- moving from %s to %s" % (rg, new_rg)
# def query_downstream(self, **template):
# return self.__query__('i->o', template)
# def query_upstream(self, **template):
# return self.__query__('o->i', template)
# def publish(self): # class ChangeListener(Listener):
# self.env.query = self.query
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
from copy import copy from copy import copy
from result import BrokenLink, BrokenLinkError
from op import Op
import utils import utils
...@@ -11,11 +9,17 @@ __all__ = ['inputs', ...@@ -11,11 +9,17 @@ __all__ = ['inputs',
'ops', 'ops',
'clone', 'clone_get_equiv', 'clone', 'clone_get_equiv',
'io_toposort', 'io_toposort',
'default_leaf_formatter', 'default_node_formatter',
'op_as_string',
'as_string', 'as_string',
'Graph'] 'Graph']
def inputs(o, repair = False): is_result = utils.attr_checker('owner', 'index')
is_op = utils.attr_checker('inputs', 'outputs')
def inputs(o):
""" """
o -> list of output Results o -> list of output Results
...@@ -24,21 +28,12 @@ def inputs(o, repair = False): ...@@ -24,21 +28,12 @@ def inputs(o, repair = False):
""" """
results = set() results = set()
def seek(r): def seek(r):
if isinstance(r, BrokenLink):
raise BrokenLinkError
op = r.owner op = r.owner
if op is None: if op is None:
results.add(r) results.add(r)
else: else:
for i in range(len(op.inputs)): for input in op.inputs:
try: seek(input)
seek(op.inputs[i])
except BrokenLinkError:
if repair:
op.refresh()
seek(op.inputs[i])
else:
raise
for output in o: for output in o:
seek(output) seek(output)
return results return results
...@@ -60,8 +55,6 @@ def results_and_orphans(i, o): ...@@ -60,8 +55,6 @@ def results_and_orphans(i, o):
incomplete_paths = [] incomplete_paths = []
def helper(r, path): def helper(r, path):
if isinstance(r, BrokenLink):
raise BrokenLinkError
if r in i: if r in i:
results.update(path) results.update(path)
elif r.owner is None: elif r.owner is None:
...@@ -128,45 +121,56 @@ def orphans(i, o): ...@@ -128,45 +121,56 @@ def orphans(i, o):
return results_and_orphans(i, o)[1] return results_and_orphans(i, o)[1]
def clone(i, o): def clone(i, o, copy_inputs = False):
""" """
i -> list of input Results i -> list of input Results
o -> list of output Results o -> list of output Results
copy_inputs -> if True, the inputs will be copied (defaults to False)
Copies the subgraph contained between i and o and returns the Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o). The input Results in outputs of that copy (corresponding to o).
the list are _not_ copied and the new graph refers to the
originals.
""" """
new_o, equiv = clone_get_equiv(i, o) equiv = clone_get_equiv(i, o)
return new_o return [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs = False): def clone_get_equiv(i, o, copy_inputs = False):
""" """
i -> list of input Results i -> list of input Results
o -> list of output Results o -> list of output Results
copy_inputs -> if True, the inputs will be replaced in the cloned
graph by copies available in the equiv dictionary
returned by the function (copy_inputs defaults to False)
Returns (new_o, equiv) where new_o are the outputs of a copy of Returns equiv a dictionary mapping each result and op in the
the whole subgraph bounded by i and o and equiv is a dictionary graph delimited by i and o to a copy (akin to deepcopy's memo).
that maps the original ops and results found in the subgraph to
their copy (akin to deepcopy's memo). See clone for more details.
""" """
d = {} d = {}
for op in ops(i, o): for input in i:
d[op] = copy(op) if copy_inputs:
d[input] = copy(input)
else:
d[input] = input
def clone_helper(result):
if result in d:
return d[result]
op = result.owner
if not op:
return result
else:
new_op = op.__class__(*[clone_helper(input) for input in op.inputs])
d[op] = new_op
for output, new_output in zip(op.outputs, new_op.outputs):
d[output] = new_output
return d[result]
for old_op, op in d.items(): for output in o:
for old_output, output in zip(old_op.outputs, op.outputs): clone_helper(output)
d[old_output] = output
for i, input in enumerate(op.inputs):
owner = input.owner
if owner in d:
op._inputs[i] = d[owner].outputs[input._index]
return [[d[output] for output in o], d] return d
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
...@@ -188,17 +192,32 @@ def io_toposort(i, o, orderings = {}): ...@@ -188,17 +192,32 @@ def io_toposort(i, o, orderings = {}):
all = ops(i, o) all = ops(i, o)
for op in all: for op in all:
prereqs_d.setdefault(op, set()).update(set([input.owner for input in op.inputs if input.owner and input.owner in all])) prereqs_d.setdefault(op, set()).update(set([input.owner for input in op.inputs if input.owner and input.owner in all]))
# prereqs_d[op] = set([input.owner for input in op.inputs if input.owner and input.owner in all])
return utils.toposort(prereqs_d) return utils.toposort(prereqs_d)
def as_string(i, o): default_leaf_formatter = str
default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings))
def op_as_string(i, op,
leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter):
strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs)
def as_string(i, o,
leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter):
""" """
i -> list of input Results i -> list of input Results
o -> list of output Results o -> list of output Results
leaf_formatter -> function that takes a result and returns a string to describe it
node_formatter -> function that takes an op and the list of strings corresponding
to its arguments and returns a string to describe it
Returns a string representation of the subgraph between i and o. If the same Returns a string representation of the subgraph between i and o. If the same
Op is used by several other ops, the first occurrence will be marked as op is used by several other ops, the first occurrence will be marked as
'*n -> description' and all subsequent occurrences will be marked as '*n', '*n -> description' and all subsequent occurrences will be marked as '*n',
where n is an id number (ids are attributed in an unspecified order and only where n is an id number (ids are attributed in an unspecified order and only
exist for viewing convenience). exist for viewing convenience).
...@@ -219,50 +238,37 @@ def as_string(i, o): ...@@ -219,50 +238,37 @@ def as_string(i, o):
done = set() done = set()
def multi_index(x): def multi_index(x):
try: return multi.index(x) + 1
return multi.index(x) + 1
except: def describe(r):
return 999 if r.owner is not None and r not in i:
op = r.owner
def describe(x, first = False): idx = op.outputs.index(r)
if isinstance(x, Result): if idx == op._default_output_idx:
done.add(x) idxs = ""
if x.owner is not None and x not in i:
op = x.owner
idx = op.outputs.index(x)
if idx:
s = describe(op, first) + "." + str(idx)
else:
s = describe(op, first)
return s
else: else:
return str(id(x)) idxs = "::%i" % idx
if op in done:
elif isinstance(x, Op): return "*%i%s" % (multi_index(x), idxs)
if x in done:
return "*%i" % multi_index(x)
else: else:
done.add(x) done.add(op)
if not first and hasattr(x, 'name') and x.name is not None: s = node_formatter(op, [describe(input) for input in op.inputs])
return x.name if op in multi:
s = x.__class__.__name__ + "(" + ", ".join([describe(v) for v in x.inputs]) + ")"
if x in multi:
return "*%i -> %s" % (multi_index(x), s) return "*%i -> %s" % (multi_index(x), s)
else: else:
return s return s
else: else:
raise TypeError("Cannot print type: %s" % x.__class__) return leaf_formatter(r)
return "[" + ", ".join([describe(x, True) for x in o]) + "]" return [describe(output) for output in o]
# Op.__str__ = lambda self: as_string(inputs(self.outputs), self.outputs)[1:-1]
# Result.__str__ = lambda self: as_string(inputs([self]), [self])[1:-1]
class Graph: class Graph:
"""
Object-oriented wrapper for all the functions in this module.
"""
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
self.inputs = inputs self.inputs = inputs
......
...@@ -4,8 +4,9 @@ Contains the Op class, which is the base interface for all operations ...@@ -4,8 +4,9 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines. compatible with gof's graph manipulation routines.
""" """
from result import BrokenLink from result import BrokenLinkError
from utils import ClsInit, all_bases, all_bases_collect from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
from copy import copy from copy import copy
...@@ -29,10 +30,6 @@ class Op(object): ...@@ -29,10 +30,6 @@ class Op(object):
__slots__ = ['_inputs', '_outputs'] __slots__ = ['_inputs', '_outputs']
#create inputs and outputs as read-only attributes
inputs = property(lambda self: self._inputs, doc = "The list of this Op's input Results.")
outputs = property(lambda self: self._outputs, doc = "The list of this Op's output Results.")
_default_output_idx = 0 _default_output_idx = 0
def default_output(self): def default_output(self):
...@@ -41,240 +38,469 @@ class Op(object): ...@@ -41,240 +38,469 @@ class Op(object):
return self.outputs[self._default_output_idx] return self.outputs[self._default_output_idx]
except (IndexError, TypeError): except (IndexError, TypeError):
raise AttributeError("Op does not have a default output.") raise AttributeError("Op does not have a default output.")
out = property(default_output, out = property(default_output,
doc = "Same as self.outputs[0] if this Op's has_default_output field is True.") doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __init__(self, inputs, outputs, use_self_setters = False):
"""
Initializes the '_inputs' and '_outputs' slots and sets the
owner of all outputs to self.
If use_self_setters is False, Op::set_input and Op::set_output
are used, which do the minimum checks and manipulations. Else,
the user defined set_input and set_output functions are
called (in any case, all inputs and outputs are initialized
to None).
"""
self._inputs = [None] * len(inputs)
self._outputs = [None] * len(outputs)
if use_self_setters:
for i, input in enumerate(inputs):
self.set_input(i, input, validate = False)
for i, output in enumerate(outputs):
self.set_output(i, output, validate = False)
self.validate()
else:
for i, input in enumerate(inputs):
Op.set_input(self, i, input, validate = False)
for i, output in enumerate(outputs):
Op.set_output(self, i, output, validate = False)
self.validate()
self.validate()
def __init__(self, *inputs):
self._inputs = None
self._outputs = None
self.__set_inputs(inputs)
self.validate_update()
def set_input(self, i, input, allow_changes = False, validate = True): def __get_input(self, i):
""" input = self._inputs[i]
Sets the ith input of self.inputs to input. i must be an if input.replaced:
integer in the range from 0 to len(self.inputs) - 1 and input raise BrokenLinkError()
must be a Result instance. The method may raise a GofTypeError return input
or a GofValueError accordingly to the semantics of the Op, if def __set_input(self, i, new_input):
the new input is of the wrong type or has the wrong self._inputs[i] = new_input
properties.
If i > len(self.inputs), an IndexError must be raised. If i == def __get_inputs(self):
len(self.inputs), it is allowed for the Op to extend the list for input in self._inputs:
of inputs if it is a vararg Op, else an IndexError should be if input.replaced:
raised. raise BrokenLinkError()
return self._inputs
For a vararg Op, it is also allowed to have the input def __set_inputs(self, new_inputs):
parameter set to None for 0 <= i < len(self.inputs), in which self._inputs = list(new_inputs)
case the rest of the inputs will be shifted left. In any other
situation, a ValueError should be raised.
def __get_output(self, i):
In some cases, set_input may change some outputs: for example, return self._outputs[i]
a change of an input from float to double might require the def __set_output(self, i, new_output):
output's type to also change from float to double. If old_output = self._outputs[i]
allow_changes is True, set_input is allowed to perform those if old_output != new_output:
changes and must return a list of pairs, each pair containing old_output.replaced = True
the old output and the output it was replaced with (they
_must_ be different Result instances). See Op::set_output for
important information about replacing outputs. If
allow_changes is False and some change in the outputs is
required for the change in input to be correct, a
PropagationError must be raised.
This default implementation sets the ith input to input and
changes no outputs. It returns None.
"""
previous = self.inputs[i]
self.inputs[i] = input
if validate:
try: try:
self.validate() # We try to reuse the old storage, if there is one
new_output.data = old_output.data
except: except:
# this call gives a subclass the chance to undo the set_outputs pass
# that it may have triggered... new_output.role = (self, i)
# TODO: test this functionality! self._outputs[i] = new_output
self.set_input(i, previous, True, False)
def __get_outputs(self):
return self._outputs
def __set_outputs(self, new_outputs):
if self._outputs is None:
for i, output in enumerate(new_outputs):
output.role = (self, i)
self._outputs = new_outputs
return
if len(self._outputs) != len(new_outputs):
raise TypeError("The new outputs must be exactly as many as the previous outputs.")
for i, new_output in enumerate(new_outputs):
self.__set_output(i, new_output)
def get_input(self, i):
return self.__get_input(i)
def set_input(self, i, new_input):
old_input = self.__get_input(i)
try:
self.__set_input(i, new_input)
self.validate_update()
except:
self.__set_input(i, old_input)
self.validate_update()
raise
def get_inputs(self):
return self.__get_inputs()
def set_inputs(self, new_inputs):
old_inputs = self.__get_inputs()
try:
self.__set_inputs(new_inputs)
self.validate_update()
except:
self._inputs = old_inputs
raise
def get_output(self, i):
return self.__get_output(i)
def get_outputs(self):
return self.__get_outputs()
#create inputs and outputs as read-only attributes
inputs = property(get_inputs, set_inputs, doc = "The list of this Op's input Results.")
outputs = property(get_outputs, __set_outputs, doc = "The list of this Op's output Results.")
def validate_update(self):
"""
(Abstract) This function must do two things:
* validate: check all the inputs in self.inputs to ensure
that they have the right type for this Op, etc.
If the validation fails, raise an exception.
* update: create output Results and set the Op's outputs
The existing outputs must not be changed in place.
The return value of validate_update is not used.
"""
raise AbstractFunctionError()
def set_output(self, i, output, validate = True): def repair(self):
""" """
Sets the ith output to output. The previous output, which is Repairs all the inputs that are broken links to use what
being replaced, must be invalidated using Result::invalidate. they were replaced with. Then, calls self.validate_update()
The new output must not already have an owner, or its owner must to validate the new inputs and make new outputs.
be self. It cannot be a broken link, unless it used to be at this
spot, in which case it can be reinstated.
For Ops that have vararg output lists, see the regulations in
Op::set_input.
""" """
if isinstance(output.owner, BrokenLink) \ changed = False
and output.owner.owner is self \ repaired_inputs = []
and output.owner.index == i: old_inputs = self._inputs
output.revalidate() for input in self._inputs:
else: if input.replaced:
output.set_owner(self, i) # this checks for an already existing owner changed = True
previous = self.outputs[i] role = input.role.old_role
if previous: input = role[0].outputs[role[1]]
previous.invalidate() repaired_inputs.append(input)
self.outputs[i] = output if changed:
if validate:
try: try:
self.validate() self.__set_inputs(repaired_inputs)
self.validate_update()
except: except:
self.set_output(i, previous, False) self._inputs = old_inputs
raise
return changed
def _dontuse_repair(self, allow_changes = False): #
# copy
#
def __copy__(self):
""" """
This function attempts to repair all inputs that are broken Shallow copy of this Op. The inputs are the exact same, but
links by calling set_input on the new Result that replaced the outputs are recreated because of the one-owner-per-result
them. Note that if a set_input operation invalidates one or policy.
more outputs, new broken links might appear in the other ops
that use this op's outputs.
It is possible that the new inputs are inconsistent with this
op, in which case an exception will be raised and the previous
inputs (and outputs) will be restored.
refresh returns a list of (old_output, new_output) pairs
detailing the changes, if any.
""" """
backtrack = [] return self.__class__(*self.inputs)
try:
for i, input in enumerate(self.inputs):
link = input.owner
if isinstance(link, BrokenLink):
current = link.owner.outputs[link.index]
dirt = self.set_input(i, current, allow_changes)
backtrack.append((i, input, dirt))
except:
# Restore the inputs and outputs that were successfully changed.
for i, input, dirt in backtrack:
self.inputs[i] = input
if dirt:
for old, new in dirt:
new.invalidate()
old.revalidate()
self.outputs[self.outputs.index(new)] = old
raise
all_dirt = []
for i, input, dirt in backtrack:
if dirt:
all_dirt += dirt
return all_dirt
#
# String representation
#
def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self):
return str(self)
#
# perform
#
def perform(self): def perform(self):
""" """
Performs the computation on the inputs and stores the results (abstract) Performs the computation associated to this Op,
in the outputs. This function should check for the validity of places the result(s) in the output Results and gives them
the inputs and raise appropriate errors for debugging (for the Computed status.
executing without checks, override _perform).
An Op may define additional ways to perform the computation
that are more efficient (e.g. a piece of C code or a C struct
with direct references to the inputs and outputs), but
perform() should always be available in order to have a
consistent interface to execute graphs.
""" """
raise NotImplementedError raise AbstractFunctionError()
def _perform(self): #
# C code generators
#
def c_var_names(self):
""" """
Performs the computation on the inputs and stores the results Returns ([list of input names], [list of output names]) for
in the outputs, like perform(), but is not required to check use as C variables.
the existence or the validity of the inputs.
""" """
return self.perform() return [["i%i" % i for i in xrange(len(self.inputs))],
["o%i" % i for i in xrange(len(self.outputs))]]
@classmethod def c_validate(self):
def require(cls):
""" """
Returns a set of Feature subclasses that must be used by any Returns C code that checks that the inputs to this function
Env manipulating this kind of op. For instance, a Destroyer can be worked on. If a failure occurs, set an Exception
requires ext.DestroyHandler to guarantee that various and insert "%(fail)s".
destructive operations don't interfere.
You may use the variable names defined by c_var_names()
By default, this collates the __require__ field of this class
and the __require__ fields of all classes that are directly or
indirectly superclasses to this class into a set.
""" """
r = set() raise AbstractFunctionError()
bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
for base in bases:
req = base.__env_require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
def c_validate_cleanup(self):
"""
Clean up things allocated by c_validate().
"""
raise AbstractFunctionError()
def validate(self): def c_update(self):
""" """
This class's __validate__ function will be called, as well as Returns C code that allocates and/or updates the outputs
the __validate__ functions of all base classes down the class (eg resizing, etc.) so they can be manipulated safely
tree. If you do not want to execute __validate__ from the base by c_code.
classes, set the class variable __validate_override__ to True.
You may use the variable names defined by c_var_names()
""" """
vfns = all_bases_collect(self.__class__, 'validate') raise AbstractFunctionError()
for vfn in vfns:
vfn(self)
def c_update_cleanup(self):
"""
Clean up things allocated by c_update().
"""
raise AbstractFunctionError()
def __copy__(self): def c_code(self):
""" """
Copies the inputs list shallowly and copies all the outputs Returns C code that does the computation associated to this
because of the one owner per output restriction. Op. You may assume that input validation and output allocation
have already been done.
You may use the variable names defined by c_var_names()
""" """
new_inputs = copy(self.inputs) raise AbstractFunctionError()
# We copy the outputs because they are tied to a single Op.
new_outputs = [copy(output) for output in self.outputs] def c_code_cleanup(self):
op = self.__class__(new_inputs, new_outputs)
op._inputs = new_inputs
op._outputs = new_outputs
for i, output in enumerate(op.outputs):
# We adjust _owner and _index manually since the copies
# point to the previous op (self).
output._owner = op
output._index = i
return op
def __deepcopy__(self, memo):
""" """
Not implemented. Use gof.graph.clone(inputs, outputs) to copy Clean up things allocated by c_code().
a subgraph.
""" """
raise NotImplementedError("Use gof.graph.clone(inputs, outputs) to copy a subgraph.") raise AbstractFunctionError()
# def __init__(self, inputs, outputs, use_self_setters = False):
# """
# Initializes the '_inputs' and '_outputs' slots and sets the
# owner of all outputs to self.
# If use_self_setters is False, Op::set_input and Op::set_output
# are used, which do the minimum checks and manipulations. Else,
# the user defined set_input and set_output functions are
# called (in any case, all inputs and outputs are initialized
# to None).
# """
# self._inputs = [None] * len(inputs)
# self._outputs = [None] * len(outputs)
# if use_self_setters:
# for i, input in enumerate(inputs):
# self.set_input(i, input, validate = False)
# for i, output in enumerate(outputs):
# self.set_output(i, output, validate = False)
# self.validate()
# else:
# for i, input in enumerate(inputs):
# Op.set_input(self, i, input, validate = False)
# for i, output in enumerate(outputs):
# Op.set_output(self, i, output, validate = False)
# self.validate()
# self.validate()
# def set_input(self, i, input, allow_changes = False, validate = True):
# """
# Sets the ith input of self.inputs to input. i must be an
# integer in the range from 0 to len(self.inputs) - 1 and input
# must be a Result instance. The method may raise a GofTypeError
# or a GofValueError accordingly to the semantics of the Op, if
# the new input is of the wrong type or has the wrong
# properties.
# If i > len(self.inputs), an IndexError must be raised. If i ==
# len(self.inputs), it is allowed for the Op to extend the list
# of inputs if it is a vararg Op, else an IndexError should be
# raised.
# For a vararg Op, it is also allowed to have the input
# parameter set to None for 0 <= i < len(self.inputs), in which
# case the rest of the inputs will be shifted left. In any other
# situation, a ValueError should be raised.
# In some cases, set_input may change some outputs: for example,
# a change of an input from float to double might require the
# output's type to also change from float to double. If
# allow_changes is True, set_input is allowed to perform those
# changes and must return a list of pairs, each pair containing
# the old output and the output it was replaced with (they
# _must_ be different Result instances). See Op::set_output for
# important information about replacing outputs. If
# allow_changes is False and some change in the outputs is
# required for the change in input to be correct, a
# PropagationError must be raised.
# This default implementation sets the ith input to input and
# changes no outputs. It returns None.
# """
# previous = self.inputs[i]
# self.inputs[i] = input
# if validate:
# try:
# self.validate()
# except:
# # this call gives a subclass the chance to undo the set_outputs
# # that it may have triggered...
# # TODO: test this functionality!
# self.set_input(i, previous, True, False)
# def set_output(self, i, output, validate = True):
# """
# Sets the ith output to output. The previous output, which is
# being replaced, must be invalidated using Result::invalidate.
# The new output must not already have an owner, or its owner must
# be self. It cannot be a broken link, unless it used to be at this
# spot, in which case it can be reinstated.
# For Ops that have vararg output lists, see the regulations in
# Op::set_input.
# """
# if isinstance(output.owner, BrokenLink) \
# and output.owner.owner is self \
# and output.owner.index == i:
# output.revalidate()
# else:
# output.set_owner(self, i) # this checks for an already existing owner
# previous = self.outputs[i]
# if previous:
# previous.invalidate()
# self.outputs[i] = output
# if validate:
# try:
# self.validate()
# except:
# self.set_output(i, previous, False)
# def _dontuse_repair(self, allow_changes = False):
# """
# This function attempts to repair all inputs that are broken
# links by calling set_input on the new Result that replaced
# them. Note that if a set_input operation invalidates one or
# more outputs, new broken links might appear in the other ops
# that use this op's outputs.
# It is possible that the new inputs are inconsistent with this
# op, in which case an exception will be raised and the previous
# inputs (and outputs) will be restored.
# refresh returns a list of (old_output, new_output) pairs
# detailing the changes, if any.
# """
# backtrack = []
# try:
# for i, input in enumerate(self.inputs):
# link = input.owner
# if isinstance(link, BrokenLink):
# current = link.owner.outputs[link.index]
# dirt = self.set_input(i, current, allow_changes)
# backtrack.append((i, input, dirt))
# except:
# # Restore the inputs and outputs that were successfully changed.
# for i, input, dirt in backtrack:
# self.inputs[i] = input
# if dirt:
# for old, new in dirt:
# new.invalidate()
# old.revalidate()
# self.outputs[self.outputs.index(new)] = old
# raise
# all_dirt = []
# for i, input, dirt in backtrack:
# if dirt:
# all_dirt += dirt
# return all_dirt
# def perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs. This function should check for the validity of
# the inputs and raise appropriate errors for debugging (for
# executing without checks, override _perform).
# An Op may define additional ways to perform the computation
# that are more efficient (e.g. a piece of C code or a C struct
# with direct references to the inputs and outputs), but
# perform() should always be available in order to have a
# consistent interface to execute graphs.
# """
# raise NotImplementedError
# def _perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs, like perform(), but is not required to check
# the existence or the validity of the inputs.
# """
# return self.perform()
# @classmethod
# def require(cls):
# """
# Returns a set of Feature subclasses that must be used by any
# Env manipulating this kind of op. For instance, a Destroyer
# requires ext.DestroyHandler to guarantee that various
# destructive operations don't interfere.
# By default, this collates the __require__ field of this class
# and the __require__ fields of all classes that are directly or
# indirectly superclasses to this class into a set.
# """
# r = set()
# bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
# for base in bases:
# req = base.__env_require__
# if isinstance(req, (list, tuple)):
# r.update(req)
# else:
# r.add(req)
# return r
# def validate(self):
# """
# This class's __validate__ function will be called, as well as
# the __validate__ functions of all base classes down the class
# tree. If you do not want to execute __validate__ from the base
# classes, set the class variable __validate_override__ to True.
# """
# vfns = all_bases_collect(self.__class__, 'validate')
# for vfn in vfns:
# vfn(self)
# def __copy__(self):
# """
# Copies the inputs list shallowly and copies all the outputs
# because of the one owner per output restriction.
# """
# new_inputs = copy(self.inputs)
# # We copy the outputs because they are tied to a single Op.
# new_outputs = [copy(output) for output in self.outputs]
# op = self.__class__(new_inputs, new_outputs)
# op._inputs = new_inputs
# op._outputs = new_outputs
# for i, output in enumerate(op.outputs):
# # We adjust _owner and _index manually since the copies
# # point to the previous op (self).
# output._owner = op
# output._index = i
# return op
# def __deepcopy__(self, memo):
# """
# Not implemented. Use gof.graph.clone(inputs, outputs) to copy
# a subgraph.
# """
# raise NotImplementedError("Use gof.graph.clone(inputs, outputs) to copy a subgraph.")
...@@ -5,37 +5,32 @@ value that is the input or the output of an Op. ...@@ -5,37 +5,32 @@ value that is the input or the output of an Op.
""" """
import unittest
from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
from python25 import all from python25 import all
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ] __all__ = ['is_result',
'ResultBase',
'BrokenLink',
'BrokenLinkError',
'StateError',
'Empty',
'Allocated',
'Computed',
]
class BrokenLink: class BrokenLink:
""" """The owner of a Result that was replaced by another Result"""
This is placed as the owner of a Result that was replaced by __slots__ = ['old_role']
another Result. def __init__(self, role): self.old_role = role
""" def __nonzero__(self): return False
__slots__ = ['owner', 'index'] class BrokenLinkError(Exception):
"""The owner is a BrokenLink"""
def __init__(self, owner, index): class StateError(Exception):
self.owner = owner """The state of the Result is a problem"""
self.index = index
def __nonzero__(self):
return False
class BrokenLinkError(GofError):
"""
"""
pass
# ResultBase state keywords # ResultBase state keywords
...@@ -61,6 +56,7 @@ class ResultBase(object): ...@@ -61,6 +56,7 @@ class ResultBase(object):
_data - anything _data - anything
constant - Boolean constant - Boolean
state - one of (Empty, Allocated, Computed) state - one of (Empty, Allocated, Computed)
name - string
Properties: Properties:
role - (rw) role - (rw)
...@@ -89,33 +85,17 @@ class ResultBase(object): ...@@ -89,33 +85,17 @@ class ResultBase(object):
expressions). expressions).
""" """
class BrokenLink:
"""The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role']
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
class BrokenLinkError(Exception):
"""The owner is a BrokenLink"""
class StateError(Exception): __slots__ = ['_role', 'constant', '_data', 'state', '_name']
"""The state of the Result is a problem"""
__slots__ = ['_role', 'constant', '_data', 'state'] def __init__(self, role=None, data=None, constant=False, name=None):
def __init__(self, role=None, data=None, constant=False):
self._role = role self._role = role
self.constant = constant
self._data = [None] self._data = [None]
if data is None: #None is not filtered self.state = Empty
self._data[0] = None self.constant = False
self.state = Empty self.__set_data(data)
else: self.constant = constant # can only lock data after setting it
try: self.name = name
self._data[0] = self.data_filter(data)
except AbstractFunctionError:
self._data[0] = data
self.state = Computed
# #
# role # role
...@@ -144,7 +124,7 @@ class ResultBase(object): ...@@ -144,7 +124,7 @@ class ResultBase(object):
def __get_owner(self): def __get_owner(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise BrokenLinkError()
return self._role[0] return self._role[0]
owner = property(__get_owner, owner = property(__get_owner,
...@@ -156,7 +136,7 @@ class ResultBase(object): ...@@ -156,7 +136,7 @@ class ResultBase(object):
def __get_index(self): def __get_index(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise BrokenLinkError()
return self._role[1] return self._role[1]
index = property(__get_index, index = property(__get_index,
...@@ -171,35 +151,39 @@ class ResultBase(object): ...@@ -171,35 +151,39 @@ class ResultBase(object):
return self._data[0] return self._data[0]
def __set_data(self, data): def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced:
if self.constant: raise Exception('cannot set constant ResultBase') raise BrokenLinkError()
if data is self._data[0]:
return
if self.constant:
raise Exception('cannot set constant ResultBase')
if data is None: if data is None:
self._data[0] = None self._data[0] = None
self.state = Empty self.state = Empty
return return
if data is self or data is self._data[0]: return
try: try:
self._data[0] = self.data_filter(data) self.validate(data)
except AbstractFunctionError: #use default behaviour except AbstractFunctionError:
self._data[0] = data pass
if isinstance(data, ResultBase): self._data[0] = data
raise Exception()
self.state = Computed self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def validate(self, data):
"""(abstract) Return an appropriate _data based on data. """(abstract) Raise an exception if the data is not of an
acceptable type.
If a subclass overrides this function, then that overriding If a subclass overrides this function, __set_data will use
implementation will be used in __set_data to map the argument to it to check that the argument can be used properly. This gives
self._data. This gives a subclass the opportunity to ensure that a subclass the opportunity to ensure that the contents of
the contents of self._data remain sensible. self._data remain sensible.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
# #
# alloc # alloc
# #
...@@ -229,19 +213,100 @@ class ResultBase(object): ...@@ -229,19 +213,100 @@ class ResultBase(object):
# #
def __get_replaced(self): def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink) return isinstance(self._role, BrokenLink)
def __set_replaced(self, replace): def __set_replaced(self, replace):
if replace == self.replaced: return if replace == self.replaced: return
if replace: if replace:
self._role = ResultBase.BrokenLink(self._role) self._role = BrokenLink(self._role)
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
#
# C code generators
#
def c_extract(self):
get_from_list = """
PyObject* py_%(name)s = PyList_GET_ITEM(%(name)s_storage, 0);
Py_XINCREF(py_%(name)s);
"""
return self.c_data_extract() + get_from_list
def c_data_extract(self):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
The Python object self.data is in a variable called "py_%(name)s" and
this code must declare a variable named "%(name)s" of a type appropriate
to manipulate from C. Additional variables and typedefs can be produced.
If the data is improper, set an appropriate error message and insert
"%(fail)s".
"""
raise AbstractFunction()
def c_sync(self, var_name):
set_in_list = """
PyList_SET_ITEM(%(name)s_storage, 0, py_%(name)s);
Py_XDECREF(py_%(name)s);
"""
return self.c_data_sync() + set_in_list
def c_data_sync(self):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
The returned code may set "py_%(name)s" to a PyObject* and that PyObject*
will be accessible from Python via result.data. Do not forget to adjust
reference counts if "py_%(name)s" is changed from its original value!
"""
raise AbstractFunction()
#
# name
#
def __get_name(self):
if self._name:
return self._name
elif self._role:
return "%s.%i" % (self.owner.__class__, self.owner.outputs.index(self))
else:
return None
def __set_name(self, name):
if name is not None and not isinstance(name, str):
raise TypeError("Name is expected to be a string, or None.")
self._name = name
name = property(__get_name, __set_name,
doc = "Name of the Result.")
#
# String representation
#
def __str__(self):
name = self.name
if name:
if self.state is Computed:
return name + ":" + str(self.data)
else:
return name
elif self.state is Computed:
return str(self.data)
else:
return "<?>"
def __repr__(self):
return self.name or "<?>"
################# #################
# NumpyR Compatibility # NumpyR Compatibility
# #
...@@ -252,26 +317,26 @@ class ResultBase(object): ...@@ -252,26 +317,26 @@ class ResultBase(object):
def set_value(self, value): def set_value(self, value):
self.data = value #may raise exception self.data = value #may raise exception
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
# def c_data_extract(self):
# return """
# PyArrayObject* %%(name)s;
# if (py_%%(name)s == Py_None)
# %%(name)s = NULL;
# else
# %%(name)s = (PyArrayObject*)(py_%%(name)s);
# typedef %(dtype)s %%(name)s_dtype;
# """ % dict(dtype = self.dtype)
# def c_data_sync(self):
# return """
# if (!%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = Py_None;
# }
# else if ((void*)py_%(name)s != (void*)%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = (PyObject*)%(name)s;
# }
# """
...@@ -12,6 +12,15 @@ class AbstractFunctionError(Exception): ...@@ -12,6 +12,15 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
if not hasattr(candidate, attr):
return False
return True
f.__doc__ = "Checks that the candidate has the following attributes: %s" % ", ".join(["'%s'"%attr for attr in attrs])
return f
def all_bases(cls, accept): def all_bases(cls, accept):
rval = set([cls]) rval = set([cls])
...@@ -34,21 +43,6 @@ def all_bases_collect(cls, raw_name): ...@@ -34,21 +43,6 @@ def all_bases_collect(cls, raw_name):
def uniq_features(_features, *_rest):
"""Return a list such that no element is a subclass of another"""
# used in Env.__init__ to
features = [x for x in _features]
for other in _rest:
features += [x for x in other]
res = []
while features:
feature = features.pop()
for feature2 in features:
if issubclass(feature2, feature):
break
else:
res.append(feature)
return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论