提交 a5129f65 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

redone result, op, tests for result, op, graph

上级 4a6845b1
import unittest
from graph import *
from op import Op
from result import ResultBase, BrokenLinkError
class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self):
return str(self.thingy)
def __repr__(self):
return str(self.thingy)
class MyOp(Op):
def validate_update(self):
for input in self.inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_inputs(unittest.TestCase):
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert inputs(op.outputs) == set([r1, r2])
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5])
class _test_orphans(unittest.TestCase):
def test_0(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5])
class _test_as_string(unittest.TestCase):
leaf_formatter = str
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings))
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"]
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_2(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_3(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"]
class _test_clone(unittest.TestCase):
def test_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
new = clone([r1, r2], op.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"]
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0]
assert op2 is not new[0].owner
assert new[0].owner.inputs[1] is r5
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0]
def test_2(self):
"Checks that manipulating a cloned graph leaves the original unchanged."
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs)
new_op = new[0].owner
new_op.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"]
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"]
if __name__ == '__main__':
unittest.main()
import unittest
from copy import copy
from op import *
from result import ResultBase, BrokenLinkError
class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
def __eq__(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self):
return str(self.thingy)
def __repr__(self):
return str(self.thingy)
class MyOp(Op):
def validate_update(self):
for input in self.inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class _test_Op(unittest.TestCase):
# Sanity tests
def test_sanity_0(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert op.inputs == [r1, r2] # Are the inputs what I provided?
assert op.outputs == [MyResult(3)] # Are the outputs what I expect?
# validate_update
def test_validate_update(self):
try:
MyOp(ResultBase(), MyResult(1)) # MyOp requires MyResult instances
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception")
# Setting inputs and outputs
def test_set_inputs(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.outputs[0]
op.inputs = MyResult(4), MyResult(5)
assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
def test_set_bad_inputs(self):
op = MyOp(MyResult(1), MyResult(2))
try:
op.inputs = MyResult(4), ResultBase()
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception")
def test_set_outputs(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) # here we only make one output
try:
op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
except TypeError, e:
assert str(e) == "The new outputs must be exactly as many as the previous outputs."
else:
raise Exception("Expected an exception")
# Tests about broken links
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op.inputs = MyResult(3), MyResult(4)
assert r3 not in op.outputs
assert r3.replaced
def test_cannot_copy_when_input_is_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
copy(op2)
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_get_input_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
op2.get_input(0)
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_get_inputs_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3)
op.inputs = MyResult(3), MyResult(4)
try:
op2.get_inputs()
except BrokenLinkError:
pass
else:
raise Exception("Expected an exception")
def test_repair_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
r3 = op.out
op2 = MyOp(r3, MyResult(10))
op.inputs = MyResult(3), MyResult(4)
op2.repair()
assert op2.outputs == [MyResult(17)]
# Tests about string representation
def test_create_broken_link(self):
r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2)
assert str(op) == "MyOp(1, 2)"
if __name__ == '__main__':
unittest.main()
import unittest
from result import *
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
......@@ -6,53 +6,21 @@ from utils import ClsInit
from err import GofError, GofTypeError, PropagationError
from op import Op
from result import is_result
from features import Listener, Orderings, Constraint, Tool
from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils
__all__ = ['InconsistencyError',
'Env']
# class AliasDict(dict):
# "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
#TODO: why is this not in err.py? -James
class InconsistencyError(GofError):
class InconsistencyError(Exception):
"""
This exception is raised by Env whenever one of the listeners marks
the graph as inconsistent.
"""
pass
def require_set(cls):
"""Return the set of objects named in a __env_require__ field in a base class"""
r = set()
......@@ -70,22 +38,22 @@ def require_set(cls):
r.add(req)
return r
class Env(graph.Graph):
"""
An Env represents a subgraph bound by a set of input results and a set of output
results. An op is in the subgraph iff it depends on the value of some of the Env's
inputs _and_ some of the Env's outputs depend on it. A result is in the subgraph
iff it is an input or an output of an op that is in the subgraph.
An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on
the value of some of the Env's inputs _and_ some of the Env's
outputs depend on it. A result is in the subgraph iff it is an
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a result in the
subgraph by another, e.g. replace (x + x).out by (2 * x).out. This is the basis
for optimization in omega.
The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in omega.
An Env can have listeners, which are instances of EnvListener. Each listener is
informed of any op entering or leaving the subgraph (which happens at construction
time and whenever there is a replacement). In addition to that, each listener can
implement the 'consistent' and 'ordering' methods (see EnvListener) in order to
restrict how ops in the subgraph can be related.
An Env's functionality can be extended with features, which must
be subclasses of L{Listener}, L{Constraint}, L{Orderings} or
L{Tool}.
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are both
......@@ -93,7 +61,6 @@ class Env(graph.Graph):
named as inputs will be assumed to contain fresh. In other words, the
backward search from outputs will stop at any node that has been explicitly
named as an input.
"""
### Special ###
......@@ -102,11 +69,6 @@ class Env(graph.Graph):
"""
Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated.
Features are class types derived from things in the tools file. These
can be listeners, constraints, orderings, etc. Features add much
(most?) functionality to an Env.
"""
self._features = {}
......@@ -114,31 +76,13 @@ class Env(graph.Graph):
self._constraints = {}
self._orderings = {}
self._tools = {}
# self._preprocessors = set()
# for feature in features:
# if issubclass(feature, tools.Preprocessor):
# preprocessor = feature()
# self._preprocessors.add(preprocessor)
# inputs, outputs = preprocessor.transform(inputs, outputs)
# The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = set(inputs)
self.outputs = set(outputs)
for feature_class in utils.uniq_features(features):
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# feature = feature_class(self)
# if isinstance(feature, tools.Listener):
# self._listeners.add(feature)
# if isinstance(feature, tools.Constraint):
# self._constraints.add(feature)
# if isinstance(feature, tools.Orderings):
# self._orderings.add(feature)
# if isinstance(feature, tools.Tool):
# self._tools.add(feature)
# feature.publish()
# All ops in the subgraph defined by inputs and outputs are cached in _ops
self._ops = set()
......@@ -234,20 +178,6 @@ class Env(graph.Graph):
except KeyError:
pass
# for i, feature in enumerate(self._features):
# if isinstance(feature, feature_class): # exact class or subclass, nothing to do
# return
# elif issubclass(feature_class, feature.__class__): # superclass, we replace it
# new_feature = feature_class(self)
# self._features[i] = new_feature
# break
# else:
# new_feature = feature_class(self)
# self._features.append(new_feature)
# if isinstance(new_feature, tools.Listener):
# for op in self.io_toposort():
# new_feature.on_import(op)
def get_feature(self, feature_class):
try:
return self._features[feature_class]
......@@ -325,7 +255,7 @@ class Env(graph.Graph):
self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise
# a GofTypeError
# an error.
self.__move_clients__(clients, r, new_r)
# This function undoes the replacement.
......
from copy import copy
from op import Op
from lib import DummyOp
from features import Listener, Constraint, Orderings
from env import InconsistencyError
from utils import ClsInit
import graph
# from copy import copy
# from op import Op
# from lib import DummyOp
# from features import Listener, Constraint, Orderings
# from env import InconsistencyError
# from utils import ClsInit
# import graph
#TODO: move mark_outputs_as_destroyed to the place that uses this function
#TODO: move Return to where it is used.
__all__ = ['IONames', 'mark_outputs_as_destroyed']
__all__ = ['Destroyer', 'Viewer']
class IONames:
"""
Requires assigning a name to each of this Op's inputs and outputs.
"""
__metaclass__ = ClsInit
input_names = ()
output_names = ()
@staticmethod
def __clsinit__(cls, name, bases, dct):
for names in ['input_names', 'output_names']:
if names in dct:
x = getattr(cls, names)
if isinstance(x, str):
x = [x,]
setattr(cls, names, x)
if isinstance(x, (list, tuple)):
x = [a for a in x]
setattr(cls, names, x)
for i, varname in enumerate(x):
if not isinstance(varname, str) or hasattr(cls, varname) or varname in ['inputs', 'outputs']:
raise TypeError("In %s: '%s' is not a valid input or output name" % (cls.__name__, varname))
# Set an attribute for the variable so we can do op.x to return the input or output named "x".
setattr(cls, varname,
property(lambda op, type=names.replace('_name', ''), index=i:
getattr(op, type)[index]))
else:
print 'ERROR: Class variable %s::%s is neither list, tuple, or string' % (name, names)
raise TypeError, str(names)
else:
setattr(cls, names, ())
# def __init__(self, inputs, outputs, use_self_setters = False):
# assert len(inputs) == len(self.input_names)
# assert len(outputs) == len(self.output_names)
# Op.__init__(self, inputs, outputs, use_self_setters)
def __validate__(self):
assert len(self.inputs) == len(self.input_names)
assert len(self.outputs) == len(self.output_names)
@classmethod
def n_inputs(cls):
return len(cls.input_names)
@classmethod
def n_outputs(cls):
return len(cls.output_names)
def get_by_name(self, name):
"""
Returns the input or output which corresponds to the given name.
"""
if name in self.input_names:
return self.input_names[self.input_names.index(name)]
elif name in self.output_names:
return self.output_names[self.output_names.index(name)]
else:
raise AttributeError("No such input or output name for %s: %s" % (self.__class__.__name__, name))
......
差异被折叠。
from copy import copy
from result import BrokenLink, BrokenLinkError
from op import Op
import utils
......@@ -11,11 +9,17 @@ __all__ = ['inputs',
'ops',
'clone', 'clone_get_equiv',
'io_toposort',
'default_leaf_formatter', 'default_node_formatter',
'op_as_string',
'as_string',
'Graph']
def inputs(o, repair = False):
is_result = utils.attr_checker('owner', 'index')
is_op = utils.attr_checker('inputs', 'outputs')
def inputs(o):
"""
o -> list of output Results
......@@ -24,21 +28,12 @@ def inputs(o, repair = False):
"""
results = set()
def seek(r):
if isinstance(r, BrokenLink):
raise BrokenLinkError
op = r.owner
if op is None:
results.add(r)
else:
for i in range(len(op.inputs)):
try:
seek(op.inputs[i])
except BrokenLinkError:
if repair:
op.refresh()
seek(op.inputs[i])
else:
raise
for input in op.inputs:
seek(input)
for output in o:
seek(output)
return results
......@@ -60,8 +55,6 @@ def results_and_orphans(i, o):
incomplete_paths = []
def helper(r, path):
if isinstance(r, BrokenLink):
raise BrokenLinkError
if r in i:
results.update(path)
elif r.owner is None:
......@@ -128,45 +121,56 @@ def orphans(i, o):
return results_and_orphans(i, o)[1]
def clone(i, o):
def clone(i, o, copy_inputs = False):
"""
i -> list of input Results
o -> list of output Results
copy_inputs -> if True, the inputs will be copied (defaults to False)
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o). The input Results in
the list are _not_ copied and the new graph refers to the
originals.
outputs of that copy (corresponding to o).
"""
new_o, equiv = clone_get_equiv(i, o)
return new_o
equiv = clone_get_equiv(i, o)
return [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs = False):
"""
i -> list of input Results
o -> list of output Results
copy_inputs -> if True, the inputs will be replaced in the cloned
graph by copies available in the equiv dictionary
returned by the function (copy_inputs defaults to False)
Returns (new_o, equiv) where new_o are the outputs of a copy of
the whole subgraph bounded by i and o and equiv is a dictionary
that maps the original ops and results found in the subgraph to
their copy (akin to deepcopy's memo). See clone for more details.
Returns equiv a dictionary mapping each result and op in the
graph delimited by i and o to a copy (akin to deepcopy's memo).
"""
d = {}
for op in ops(i, o):
d[op] = copy(op)
for input in i:
if copy_inputs:
d[input] = copy(input)
else:
d[input] = input
def clone_helper(result):
if result in d:
return d[result]
op = result.owner
if not op:
return result
else:
new_op = op.__class__(*[clone_helper(input) for input in op.inputs])
d[op] = new_op
for output, new_output in zip(op.outputs, new_op.outputs):
d[output] = new_output
return d[result]
for old_op, op in d.items():
for old_output, output in zip(old_op.outputs, op.outputs):
d[old_output] = output
for i, input in enumerate(op.inputs):
owner = input.owner
if owner in d:
op._inputs[i] = d[owner].outputs[input._index]
for output in o:
clone_helper(output)
return [[d[output] for output in o], d]
return d
def io_toposort(i, o, orderings = {}):
......@@ -188,17 +192,32 @@ def io_toposort(i, o, orderings = {}):
all = ops(i, o)
for op in all:
prereqs_d.setdefault(op, set()).update(set([input.owner for input in op.inputs if input.owner and input.owner in all]))
# prereqs_d[op] = set([input.owner for input in op.inputs if input.owner and input.owner in all])
return utils.toposort(prereqs_d)
def as_string(i, o):
default_leaf_formatter = str
default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__,
", ".join(argstrings))
def op_as_string(i, op,
leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter):
strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs)
def as_string(i, o,
leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter):
"""
i -> list of input Results
o -> list of output Results
leaf_formatter -> function that takes a result and returns a string to describe it
node_formatter -> function that takes an op and the list of strings corresponding
to its arguments and returns a string to describe it
Returns a string representation of the subgraph between i and o. If the same
Op is used by several other ops, the first occurrence will be marked as
op is used by several other ops, the first occurrence will be marked as
'*n -> description' and all subsequent occurrences will be marked as '*n',
where n is an id number (ids are attributed in an unspecified order and only
exist for viewing convenience).
......@@ -219,50 +238,37 @@ def as_string(i, o):
done = set()
def multi_index(x):
try:
return multi.index(x) + 1
except:
return 999
def describe(x, first = False):
if isinstance(x, Result):
done.add(x)
if x.owner is not None and x not in i:
op = x.owner
idx = op.outputs.index(x)
if idx:
s = describe(op, first) + "." + str(idx)
else:
s = describe(op, first)
return s
return multi.index(x) + 1
def describe(r):
if r.owner is not None and r not in i:
op = r.owner
idx = op.outputs.index(r)
if idx == op._default_output_idx:
idxs = ""
else:
return str(id(x))
elif isinstance(x, Op):
if x in done:
return "*%i" % multi_index(x)
idxs = "::%i" % idx
if op in done:
return "*%i%s" % (multi_index(x), idxs)
else:
done.add(x)
if not first and hasattr(x, 'name') and x.name is not None:
return x.name
s = x.__class__.__name__ + "(" + ", ".join([describe(v) for v in x.inputs]) + ")"
if x in multi:
done.add(op)
s = node_formatter(op, [describe(input) for input in op.inputs])
if op in multi:
return "*%i -> %s" % (multi_index(x), s)
else:
return s
else:
raise TypeError("Cannot print type: %s" % x.__class__)
return leaf_formatter(r)
return "[" + ", ".join([describe(x, True) for x in o]) + "]"
return [describe(output) for output in o]
# Op.__str__ = lambda self: as_string(inputs(self.outputs), self.outputs)[1:-1]
# Result.__str__ = lambda self: as_string(inputs([self]), [self])[1:-1]
class Graph:
"""
Object-oriented wrapper for all the functions in this module.
"""
def __init__(self, inputs, outputs):
self.inputs = inputs
......
差异被折叠。
......@@ -5,37 +5,32 @@ value that is the input or the output of an Op.
"""
import unittest
from err import GofError
from utils import AbstractFunctionError
from python25 import all
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ]
__all__ = ['is_result',
'ResultBase',
'BrokenLink',
'BrokenLinkError',
'StateError',
'Empty',
'Allocated',
'Computed',
]
class BrokenLink:
"""
This is placed as the owner of a Result that was replaced by
another Result.
"""
"""The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role']
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
__slots__ = ['owner', 'index']
class BrokenLinkError(Exception):
"""The owner is a BrokenLink"""
def __init__(self, owner, index):
self.owner = owner
self.index = index
def __nonzero__(self):
return False
class BrokenLinkError(GofError):
"""
"""
pass
class StateError(Exception):
"""The state of the Result is a problem"""
# ResultBase state keywords
......@@ -61,6 +56,7 @@ class ResultBase(object):
_data - anything
constant - Boolean
state - one of (Empty, Allocated, Computed)
name - string
Properties:
role - (rw)
......@@ -89,33 +85,17 @@ class ResultBase(object):
expressions).
"""
class BrokenLink:
"""The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role']
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
class BrokenLinkError(Exception):
"""The owner is a BrokenLink"""
class StateError(Exception):
"""The state of the Result is a problem"""
__slots__ = ['_role', 'constant', '_data', 'state', '_name']
__slots__ = ['_role', 'constant', '_data', 'state']
def __init__(self, role=None, data=None, constant=False):
def __init__(self, role=None, data=None, constant=False, name=None):
self._role = role
self.constant = constant
self._data = [None]
if data is None: #None is not filtered
self._data[0] = None
self.state = Empty
else:
try:
self._data[0] = self.data_filter(data)
except AbstractFunctionError:
self._data[0] = data
self.state = Computed
self.state = Empty
self.constant = False
self.__set_data(data)
self.constant = constant # can only lock data after setting it
self.name = name
#
# role
......@@ -144,7 +124,7 @@ class ResultBase(object):
def __get_owner(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
if self.replaced: raise BrokenLinkError()
return self._role[0]
owner = property(__get_owner,
......@@ -156,7 +136,7 @@ class ResultBase(object):
def __get_index(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
if self.replaced: raise BrokenLinkError()
return self._role[1]
index = property(__get_index,
......@@ -171,35 +151,39 @@ class ResultBase(object):
return self._data[0]
def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase')
if self.replaced:
raise BrokenLinkError()
if data is self._data[0]:
return
if self.constant:
raise Exception('cannot set constant ResultBase')
if data is None:
self._data[0] = None
self.state = Empty
return
if data is self or data is self._data[0]: return
try:
self._data[0] = self.data_filter(data)
except AbstractFunctionError: #use default behaviour
self._data[0] = data
if isinstance(data, ResultBase):
raise Exception()
self.validate(data)
except AbstractFunctionError:
pass
self._data[0] = data
self.state = Computed
data = property(__get_data, __set_data,
doc = "The storage associated with this result")
doc = "The storage associated with this result")
def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.
def validate(self, data):
"""(abstract) Raise an exception if the data is not of an
acceptable type.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
If a subclass overrides this function, __set_data will use
it to check that the argument can be used properly. This gives
a subclass the opportunity to ensure that the contents of
self._data remain sensible.
"""
raise AbstractFunctionError()
#
# alloc
#
......@@ -229,19 +213,100 @@ class ResultBase(object):
#
def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink)
return isinstance(self._role, BrokenLink)
def __set_replaced(self, replace):
if replace == self.replaced: return
if replace:
self._role = ResultBase.BrokenLink(self._role)
self._role = BrokenLink(self._role)
else:
self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
#
# C code generators
#
def c_extract(self):
get_from_list = """
PyObject* py_%(name)s = PyList_GET_ITEM(%(name)s_storage, 0);
Py_XINCREF(py_%(name)s);
"""
return self.c_data_extract() + get_from_list
def c_data_extract(self):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
The Python object self.data is in a variable called "py_%(name)s" and
this code must declare a variable named "%(name)s" of a type appropriate
to manipulate from C. Additional variables and typedefs can be produced.
If the data is improper, set an appropriate error message and insert
"%(fail)s".
"""
raise AbstractFunction()
def c_sync(self, var_name):
set_in_list = """
PyList_SET_ITEM(%(name)s_storage, 0, py_%(name)s);
Py_XDECREF(py_%(name)s);
"""
return self.c_data_sync() + set_in_list
def c_data_sync(self):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
The returned code may set "py_%(name)s" to a PyObject* and that PyObject*
will be accessible from Python via result.data. Do not forget to adjust
reference counts if "py_%(name)s" is changed from its original value!
"""
raise AbstractFunction()
#
# name
#
def __get_name(self):
if self._name:
return self._name
elif self._role:
return "%s.%i" % (self.owner.__class__, self.owner.outputs.index(self))
else:
return None
def __set_name(self, name):
if name is not None and not isinstance(name, str):
raise TypeError("Name is expected to be a string, or None.")
self._name = name
name = property(__get_name, __set_name,
doc = "Name of the Result.")
#
# String representation
#
def __str__(self):
name = self.name
if name:
if self.state is Computed:
return name + ":" + str(self.data)
else:
return name
elif self.state is Computed:
return str(self.data)
else:
return "<?>"
def __repr__(self):
return self.name or "<?>"
#################
# NumpyR Compatibility
#
......@@ -252,26 +317,26 @@ class ResultBase(object):
def set_value(self, value):
self.data = value #may raise exception
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
# def c_data_extract(self):
# return """
# PyArrayObject* %%(name)s;
# if (py_%%(name)s == Py_None)
# %%(name)s = NULL;
# else
# %%(name)s = (PyArrayObject*)(py_%%(name)s);
# typedef %(dtype)s %%(name)s_dtype;
# """ % dict(dtype = self.dtype)
# def c_data_sync(self):
# return """
# if (!%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = Py_None;
# }
# else if ((void*)py_%(name)s != (void*)%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = (PyObject*)%(name)s;
# }
# """
......@@ -12,6 +12,15 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class.
"""
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
if not hasattr(candidate, attr):
return False
return True
f.__doc__ = "Checks that the candidate has the following attributes: %s" % ", ".join(["'%s'"%attr for attr in attrs])
return f
def all_bases(cls, accept):
rval = set([cls])
......@@ -34,21 +43,6 @@ def all_bases_collect(cls, raw_name):
def uniq_features(_features, *_rest):
"""Return a list such that no element is a subclass of another"""
# used in Env.__init__ to
features = [x for x in _features]
for other in _rest:
features += [x for x in other]
res = []
while features:
feature = features.pop()
for feature2 in features:
if issubclass(feature2, feature):
break
else:
res.append(feature)
return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论