implemented Constructor and various handy Allocators, modified the tests to use…

implemented Constructor and various handy Allocators, modified the tests to use them instead of overriding __new__
上级 618ad84b
import unittest
from constructor import *
import random
class MyAllocator(Allocator):
def __init__(self, fn):
self.fn = fn
def __call__(self):
return self.fn.__name__
def f1(a, b, c):
return a + b + c
def f2(x):
return "!!%s" % x
class _test_Constructor(unittest.TestCase):
def test_0(self):
c = Constructor(MyAllocator)
c.update({"fifi": f1, "loulou": f2})
assert c.fifi() == 'f1' and c.loulou() == 'f2'
def test_1(self):
c = Constructor(MyAllocator)
c.add_module(random)
assert c.random.random() == 'random' and c.random.randint() == 'randint'
def test_2(self):
c = Constructor(MyAllocator)
c.update({"fifi": f1, "loulou": f2})
globals().update(c)
assert fifi() == 'f1' and loulou() == 'f2'
if __name__ == '__main__':
unittest.main()
......@@ -9,10 +9,11 @@ from ext import *
from env import Env, InconsistencyError
from toolbox import EquivTool
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
ResultBase.__init__(self, role = None, data = [1000], name = name)
def __str__(self):
return self.name
......@@ -23,11 +24,6 @@ class MyResult(ResultBase):
class MyOp(Op):
nin = -1
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
assert len(inputs) == self.nin
......@@ -66,6 +62,13 @@ t2s = OpSubOptimizer(TransposeView, Sigmoid)
s2t = OpSubOptimizer(Sigmoid, TransposeView)
from constructor import Constructor
from allocators import BuildAllocator
c = Constructor(BuildAllocator)
c.update(globals())
globals().update(c)
class _test_all(unittest.TestCase):
def inputs(self):
......
......@@ -11,7 +11,7 @@ class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
ResultBase.__init__(self, role = None, data = [self.thingy])
def __eq__(self, other):
return self.same_properties(other)
......
......@@ -9,7 +9,7 @@ class MyResult(ResultBase):
def __init__(self, thingy):
self.thingy = thingy
ResultBase.__init__(self, role = None, data = [self.thingy], constant = False)
ResultBase.__init__(self, role = None, data = [self.thingy])
def __eq__(self, other):
return self.same_properties(other)
......
......@@ -11,7 +11,7 @@ from toolbox import *
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
ResultBase.__init__(self, role = None, data = [1000], name = name)
def __str__(self):
return self.name
......@@ -22,11 +22,6 @@ class MyResult(ResultBase):
class MyOp(Op):
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
for input in inputs:
......@@ -49,6 +44,14 @@ class Op4(MyOp):
pass
from constructor import Constructor
from allocators import BuildAllocator
c = Constructor(BuildAllocator)
c.update(globals())
for k, v in c.items():
globals()[k.lower()] = v
def inputs():
x = MyResult('x')
y = MyResult('y')
......@@ -63,7 +66,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g)
......@@ -71,7 +74,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_1(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'),
(Op4, '2', '1')).optimize(g)
......@@ -79,7 +82,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_2(self):
x, y, z = inputs()
e = Op1(Op2(x, y), z)
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'),
(Op1, '2', '1')).optimize(g)
......@@ -87,7 +90,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_3(self):
x, y, z = inputs()
e = Op1(Op2(x, y), Op2(x, y), Op2(y, z))
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'),
(Op4, '1')).optimize(g)
......@@ -95,7 +98,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_4(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(x))))
e = op1(op1(op1(op1(x))))
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')),
'1').optimize(g)
......@@ -103,7 +106,7 @@ class _test_PatternOptimizer(unittest.TestCase):
def test_5(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(Op1(x)))))
e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')),
'1').optimize(g)
......@@ -114,14 +117,14 @@ class _test_OpSubOptimizer(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = Op1(Op1(Op1(Op1(Op1(x)))))
e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e])
OpSubOptimizer(Op1, Op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_1(self):
x, y, z = inputs()
e = Op1(Op2(x), Op3(y), Op4(z))
e = op1(op2(x), op3(y), op4(z))
g = env([x, y, z], [e])
OpSubOptimizer(Op3, Op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
......
......@@ -13,7 +13,7 @@ from toolbox import *
class MyResult(ResultBase):
def __init__(self, name):
ResultBase.__init__(self, role = None, data = [1000], constant = False, name = name)
ResultBase.__init__(self, role = None, data = [1000], name = name)
def __str__(self):
return self.name
......@@ -24,12 +24,7 @@ class MyResult(ResultBase):
class MyOp(Op):
nin = -1
def __new__(cls, *inputs):
op = Op.__new__(cls)
op.__init__(*inputs)
return op.out
def __init__(self, *inputs):
assert len(inputs) == self.nin
for input in inputs:
......@@ -47,6 +42,12 @@ class Add(MyOp):
class Dot(MyOp):
nin = 2
from constructor import Constructor
from allocators import BuildAllocator
c = Constructor(BuildAllocator)
c.update(globals())
globals().update(c)
def inputs():
x = MyResult('x')
......@@ -64,7 +65,7 @@ class _test_EquivTool(unittest.TestCase):
assert g.equiv(sx) is sx
g.replace(sx, Dot(x, z))
assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx).owner, Dot)
assert isinstance(g.equiv(sx).owner, Dot.opclass)
......
from constructor import Allocator
from op import Op
class OpAllocator(Allocator):
def __init__(self, opclass):
if not issubclass(opclass, Op):
raise TypeError("Expected an Op instance.")
self.opclass = opclass
class FilteredOpAllocator(OpAllocator):
def filter(self, op):
pass
def __call__(self, *inputs):
op = self.opclass(*inputs)
self.filter(op)
if len(op.outputs) == 1:
return op.outputs[0]
else:
return op.outputs
class BuildAllocator(FilteredOpAllocator):
pass
class EvalAllocator(FilteredOpAllocator):
def filter(self, op):
op.perform()
for output in op.outputs:
output.role = None
class BuildEvalAllocator(FilteredOpAllocator):
def filter(self, op):
op.perform()
from utils import AbstractFunctionError
class Allocator:
def __init__(self, fn):
self.fn = fn
class IdentityAllocator(Allocator):
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
class Constructor(dict):
def __init__(self, allocator):
self._allocator = allocator
def add_from_module(self, module):
for symbol in dir(module):
if symbol[:2] == '__': continue
obj = getattr(module, symbol)
try:
self[symbol] = self._allocator(obj)
except TypeError:
pass
def add_module(self, module, module_name = None):
if module_name is None:
module_name = module.__name__
d = Constructor(self._allocator)
d.add_from_module(module)
self[module_name] = d
def update(self, d, can_fail = False):
for name, fn in d.items():
self.add(name, fn, can_fail)
def add(self, name, fn, can_fail = True):
if isinstance(fn, Constructor):
self[name] = fn
else:
try:
self[name] = self._allocator(fn)
except TypeError:
if can_fail:
raise
def __getattr__(self, attr):
return self[attr]
# class Constructor:
# def __init__(self):
# pass
# def add_module(self, module, module_name, accept=lambda x:issubclass(x, cf.base)):
# dct = {}
# for symbol in dir(module):
# if symbol[:2] == '__': continue
# obj = getattr(module, symbol)
# if accept(obj): dct[symbol] = Allocator(obj)
# class Dummy:pass
# self.__dict__[module_name] = Dummy()
# self.__dict__[module_name].__dict__.update(dct)
# def add_from_module(self, module, accept=lambda x:issubclass(x, cf.base)):
# for symbol in dir(module):
# if symbol[:2] == '__': continue
# obj = getattr(module, symbol)
# #print 'considering', symbol, obj
# if accept(obj): self.__dict__[symbol] = Allocator(obj)
# def add_globals_from_module(self, module, accept=lambda x:issubclass(x, cf.base)):
# for symbol in dir(module):
# if symbol[:2] == '__': continue
# obj = getattr(module, symbol)
# #print 'considering', symbol, obj
# if accept(obj):
# if hasattr(globals(), symbol):
# print 'Warning, overwriting global variable: %s' % symbol
# globals()[symbol] = Allocator(obj)
# if __name__=='__main__':
# c = Constructor()
# c.add_module(cf,'cf')
# aa,bb = c.cf.A(), c.cf.B()
# print aa,bb
# c.add_from_module(cf)
# a,b = c.A(), c.B()
# print a,b
# c.add_globals_from_module(cf)
# d,e = A(), B()
# print d,e
......@@ -2,10 +2,6 @@
from copy import copy
import graph
from utils import ClsInit
from err import GofError, GofTypeError, PropagationError
from op import Op
from result import is_result
from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils
from utils import AbstractFunctionError
......@@ -236,12 +232,6 @@ class Env(graph.Graph):
be raised if there are type mismatches.
"""
# Assert that they are result instances.
if not is_result(r):
raise TypeError(r)
if not is_result(new_r):
raise TypeError(new_r)
self.__import_r_satisfy__([new_r])
# Save where we are so we can backtrack
......@@ -453,4 +443,19 @@ class Env(graph.Graph):
def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
def clone(self, clone_inputs = True):
equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new = self.__class__([equiv[input] for input in self.inputs],
[equiv[output] for output in self.outputs],
self._features.keys(),
consistency_check = False)
try:
new.set_equiv(equiv)
except AttributeError:
pass
return new
def __copy__(self):
return self.clone()
"""
This file defines the Exceptions that may be raised by graph manipulations.
"""
class GofError(Exception):
pass
class GofTypeError(GofError):
pass
class GofValueError(GofError):
pass
class PropagationError(GofError):
pass
# from features import Tool
# from utils import AbstractFunctionError
def perform_linker(env, target = None):
order = env.toposort()
thunks = [op.perform for op in order]
def ret():
for thunk in thunks:
thunk()
if not target:
return ret
else:
raise NotImplementedError("Cannot write thunk representation to a file.")
def perform_linker_nochecks(env, target = None):
order = env.toposort()
thunks = [op._perform for op in order]
def ret():
for thunk in thunks:
thunk()
if not target:
return ret
else:
raise NotImplementedError("Cannot write thunk representation to a file.")
def cthunk_linker(env):
order = env.toposort()
thunks = []
cstreak = []
def append_cstreak():
if cstreak:
thunks.append(cutils.create_cthunk_loop(*cstreak))
cstreak = []
def ret():
for thunk in thunks:
thunk()
for op in order:
if hasattr(op, 'cthunk'):
cstreak.append(op.cthunk())
else:
append_cstreak()
thunks.append(op.perform)
if len(thunks) == 1:
return thunks[0]
else:
return ret
# class Linker(Tool):
# def compile(self):
# raise AbstractFunctionError()
# def run(self):
# raise AbstractFunctionError()
# def perform_linker(env, target = None):
# order = env.toposort()
# thunks = [op.perform for op in order]
# def ret():
# for thunk in thunks:
# thunk()
# if not target:
# return ret
# else:
# raise NotImplementedError("Cannot write thunk representation to a file.")
# def perform_linker_nochecks(env, target = None):
# order = env.toposort()
# thunks = [op._perform for op in order]
# def ret():
# for thunk in thunks:
# thunk()
# if not target:
# return ret
# else:
# raise NotImplementedError("Cannot write thunk representation to a file.")
# def cthunk_linker(env):
# order = env.toposort()
# thunks = []
# cstreak = []
# def append_cstreak():
# if cstreak:
# thunks.append(cutils.create_cthunk_loop(*cstreak))
# cstreak = []
# def ret():
# for thunk in thunks:
# thunk()
# for op in order:
# if hasattr(op, 'cthunk'):
# cstreak.append(op.cthunk())
# else:
# append_cstreak()
# thunks.append(op.perform)
# if len(thunks) == 1:
# return thunks[0]
# else:
# return ret
......@@ -82,9 +82,9 @@ class OpSubOptimizer(Optimizer):
for op in candidates:
try:
# note: only replaces the default 'out' port if it exists
r = self.op2(*op.inputs)
if isinstance(r, Op):
r = r.out
r = self.op2(*op.inputs).out
# if isinstance(r, Op):
# r = r.out
env.replace(op.out, r)
except InconsistencyError, e:
# print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
......@@ -160,12 +160,13 @@ class PatternOptimizer(OpSpecificOptimizer):
if u:
try:
# note: only replaces the default 'out' port if it exists
new = build(self.out_pattern, u)
if isinstance(new, Op):
p = self.out_pattern
new = build(p, u)
if not isinstance(p, str):
new = new.out
env.replace(op.out, new)
except InconsistencyError, e:
print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
# print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
pass
......
from env import Env
from utils import AbstractFunctionError
class Linker:
def __init__(self, env):
self.env = env
self.thunk = None
def compile(self):
raise AbstractFunctionError()
def run(self):
self.thunk()
# import compile
import env
......
......@@ -6,13 +6,9 @@ value that is the input or the output of an Op.
"""
from utils import AbstractFunctionError
from python25 import all
__all__ = ['is_result',
'ResultBase',
# 'BrokenLink',
# 'BrokenLinkError',
__all__ = ['ResultBase',
'StateError',
'Empty',
'Allocated',
......@@ -20,15 +16,6 @@ __all__ = ['is_result',
]
# class BrokenLink:
# """The owner of a Result that was replaced by another Result"""
# __slots__ = ['old_role']
# def __init__(self, role): self.old_role = role
# def __nonzero__(self): return False
# class BrokenLinkError(Exception):
# """The owner is a BrokenLink"""
class StateError(Exception):
"""The state of the Result is a problem"""
......@@ -43,18 +30,12 @@ class Computed : """Memory has been allocated, contents are the owner's output."
# Result
############################
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list])
class ResultBase(object):
"""Base class for storing Op inputs and outputs
Attributes:
_role - None or (owner, index) #or BrokenLink
_data - anything
constant - Boolean
state - one of (Empty, Allocated, Computed)
name - string
......@@ -63,7 +44,6 @@ class ResultBase(object):
owner - (ro)
index - (ro)
data - (rw) : calls data_filter when setting
# replaced - (rw) : True iff _role is BrokenLink
Methods:
alloc() - create storage in data, suitable for use by C ops.
......@@ -72,29 +52,16 @@ class ResultBase(object):
Abstract Methods:
data_filter
data_alloc
# Notes (from previous implementation):
# A Result instance should be immutable: indeed, if some aspect of a
# Result is changed, operations that use it might suddenly become
# invalid. Instead, a new Result instance should be instanciated
# with the correct properties and the invalidate method should be
# called on the Result which is replaced (this will make its owner a
# BrokenLink instance, which behaves like False in conditional
# expressions).
"""
__slots__ = ['_role', 'constant', '_data', 'state', '_name']
__slots__ = ['_role', '_data', 'state', '_name']
def __init__(self, role=None, data=None, constant=False, name=None):
def __init__(self, role=None, data=None, name=None):
self._role = role
self._data = [None]
self.state = Empty
self.constant = False
self.__set_data(data)
self.constant = constant # can only lock data after setting it
self.name = name
#
......@@ -124,7 +91,6 @@ class ResultBase(object):
def __get_owner(self):
if self._role is None: return None
# if self.replaced: raise BrokenLinkError()
return self._role[0]
owner = property(__get_owner,
......@@ -136,7 +102,6 @@ class ResultBase(object):
def __get_index(self):
if self._role is None: return None
# if self.replaced: raise BrokenLinkError()
return self._role[1]
index = property(__get_index,
......@@ -151,12 +116,8 @@ class ResultBase(object):
return self._data[0]
def __set_data(self, data):
# if self.replaced:
# raise BrokenLinkError()
if data is self._data[0]:
return
if self.constant:
raise Exception('cannot set constant ResultBase')
if data is None:
self._data[0] = None
self.state = Empty
......@@ -208,27 +169,16 @@ class ResultBase(object):
raise AbstractFunctionError()
#
# replaced
#
# def __get_replaced(self):
# return isinstance(self._role, BrokenLink)
# def __set_replaced(self, replace):
# if replace == self.replaced: return
# if replace:
# self._role = BrokenLink(self._role)
# else:
# self._role = self._role.old_role
# replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
#
# C code generators
#
def c_type(self):
"""
Return a string naming the C type that Ops must use to manipulate
this Result.
"""
raise AbstractFunctionError()
def c_extract(self):
get_from_list = """
......@@ -239,16 +189,17 @@ class ResultBase(object):
def c_data_extract(self):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
The Python object self.data is in a variable called "py_%(name)s" and
this code must declare a variable named "%(name)s" of a type appropriate
to manipulate from C. Additional variables and typedefs can be produced.
If the data is improper, set an appropriate error message and insert
"%(fail)s".
The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to
call this Result. The Python object self.data is in a
variable called "py_%(name)s" and this code must declare a
variable named "%(name)s" of type "%(type)s" where "%(type)s"
will be replaced by the return value of
self.c_type(). Additional variables and typedefs can be
produced. If the data is improper, set an appropriate error
message and insert "%(fail)s".
"""
raise AbstractFunction()
raise AbstractFunctionError()
def c_sync(self, var_name):
set_in_list = """
......@@ -265,7 +216,20 @@ class ResultBase(object):
will be accessible from Python via result.data. Do not forget to adjust
reference counts if "py_%(name)s" is changed from its original value!
"""
raise AbstractFunction()
raise AbstractFunctionError()
def c_headers(self):
"""
Return a list of header files that must be included from C to manipulate
this Result.
"""
return []
def c_libraries(self):
"""
Return a list of libraries to link against to manipulate this Result.
"""
return []
#
# name
......@@ -311,16 +275,6 @@ class ResultBase(object):
# same properties
#
# def __eq__(self, other):
# if self.state is not Computed:
# raise StateError("Can only compare computed results for equality.")
# if isinstance(other, Result):
# if other.state is not Computed:
# raise StateError("Can only compare computed results for equality.")
# return self.data == other.data
# else:
# return self.data == other
def same_properties(self, other):
raise AbstractFunction()
......
......@@ -23,6 +23,10 @@ class EquivTool(Listener, Tool, dict):
def publish(self):
self.env.equiv = self
self.env.set_equiv = self.set_equiv
def set_equiv(self, d):
self.update(d)
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论