提交 93b4e940 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

cleaned up the files

上级 c9221f08
import unittest
from wrappers import *
class _testCase_input(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_input_int(self):
w = input(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_input_float(self):
w = input(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
class _testCase_wrap(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_wrap_int(self):
w = wrap(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_wrap_float(self):
w = wrap(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
class _testCase_literal(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_int(self):
w = literal(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
u = literal(1+2)
self.failUnless(u is w)
def test_float(self):
w = literal(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
u = literal(1.0+2.0)
self.failUnless(u is w)
def test_mixed(self):
f = literal(2.0)
i = literal(2)
self.failUnless(i is not f)
if __name__ == '__main__':
unittest.main()
import unittest
import constructor_fodder as cf
class Allocator:
def __init__(self, cls, ctor):
self.cls = cls
self.ctor = ctor
def __call__(self, *args, **kwargs):
rval = self.cls.__new__(self.cls, *args, **kwargs)
rval.__init__(*args, **kwargs)
return rval
class ModeOpAllocator:
def __init__(self, cls, ctor):
self.cls = cls
self.ctor = ctor
def __call__(self, *args, **kwargs):
op = self.cls.__new__(self.cls, *args, **kwargs)
op.__init__(*args, **kwargs)
mode = self.ctor.mode()
if mode == 'eval':
op.perform()
if op.nout == 1:
return op.out.data
else:
return [output.data for output in op.outputs]
elif mode == 'build_eval':
op.perform()
if op.nout == 1:
return op.out
else:
return op.outputs
class Constructor:
def __init__(self):
pass
def add_module(self, module, module_name, accept=lambda x:issubclass(x, cf.base)):
dct = {}
for symbol in dir(module):
if symbol[:2] == '__': continue
obj = getattr(module, symbol)
if accept(obj): dct[symbol] = Allocator(obj)
class Dummy:pass
self.__dict__[module_name] = Dummy()
self.__dict__[module_name].__dict__.update(dct)
def add_from_module(self, module, accept=lambda x:issubclass(x, cf.base)):
for symbol in dir(module):
if symbol[:2] == '__': continue
obj = getattr(module, symbol)
#print 'considering', symbol, obj
if accept(obj): self.__dict__[symbol] = Allocator(obj)
def add_globals_from_module(self, module, accept=lambda x:issubclass(x, cf.base)):
for symbol in dir(module):
if symbol[:2] == '__': continue
obj = getattr(module, symbol)
#print 'considering', symbol, obj
if accept(obj):
if hasattr(globals(), symbol):
print 'Warning, overwriting global variable: %s' % symbol
globals()[symbol] = Allocator(obj)
if __name__=='__main__':
c = Constructor()
c.add_module(cf,'cf')
aa,bb = c.cf.A(), c.cf.B()
print aa,bb
c.add_from_module(cf)
a,b = c.A(), c.B()
print a,b
c.add_globals_from_module(cf)
d,e = A(), B()
print d,e
class base(object): pass
class A(base): pass
class B(base): pass
class C(base): pass
差异被折叠。
import unittest
from constructor import *
import random
class MyAllocator(Allocator):
def __init__(self, fn):
self.fn = fn
def __call__(self):
return self.fn.__name__
def f1(a, b, c):
return a + b + c
def f2(x):
return "!!%s" % x
class _test_Constructor(unittest.TestCase):
def test_0(self):
c = Constructor(MyAllocator)
c.update({"fifi": f1, "loulou": f2})
assert c.fifi() == 'f1' and c.loulou() == 'f2'
def test_1(self):
c = Constructor(MyAllocator)
c.add_module(random)
assert c.random.random() == 'random' and c.random.randint() == 'randint'
def test_2(self):
c = Constructor(MyAllocator)
c.update({"fifi": f1, "loulou": f2})
globals().update(c)
assert fifi() == 'f1' and loulou() == 'f2'
if __name__ == '__main__':
unittest.main()
from constructor import Allocator
from op import Op
class OpAllocator(Allocator):
def __init__(self, opclass):
if not issubclass(opclass, Op):
raise TypeError("Expected an Op instance.")
self.opclass = opclass
class FilteredOpAllocator(OpAllocator):
def filter(self, op):
pass
def __call__(self, *inputs):
op = self.opclass(*inputs)
self.filter(op)
if len(op.outputs) == 1:
return op.outputs[0]
else:
return op.outputs
class BuildAllocator(FilteredOpAllocator):
pass
class EvalAllocator(FilteredOpAllocator):
def filter(self, op):
op.perform()
for output in op.outputs:
output.role = None
class BuildEvalAllocator(FilteredOpAllocator):
def filter(self, op):
op.perform()
from utils import AbstractFunctionError
class Dispatcher(list):
all_dispatchers = {}
def __init__(self, name, description):
self.name = name
self.description = description
self.all_dispatchers[name] = self
def __call__(self, *inputs, **opts):
for candidate in self:
try:
return candidate(*inputs, **opts)
except TypeError:
continue
if opts:
s = " with options %s" % opts
else:
s = ""
raise OmegaTypeError("No candidate found for %s(%s) %s" \
% (self.name,
", ".join([input.__class__.__name__ for input in inputs]),
s))
def add_handler(self, x):
self.insert(0, x)
def fallback_handler(self, x):
self.append(x)
class Allocator:
def __init__(self, fn):
self.fn = fn
class IdentityAllocator(Allocator):
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
class Constructor(dict):
def __init__(self, allocator):
self._allocator = allocator
def add_from_module(self, module):
for symbol in dir(module):
if symbol[:2] == '__': continue
obj = getattr(module, symbol)
try:
self[symbol] = self._allocator(obj)
except TypeError:
pass
def add_module(self, module, module_name = None):
if module_name is None:
module_name = module.__name__
d = Constructor(self._allocator)
d.add_from_module(module)
self[module_name] = d
def update(self, d, can_fail = False):
for name, fn in d.items():
self.add(name, fn, can_fail)
def add(self, name, fn, can_fail = True):
if isinstance(fn, Constructor):
self[name] = fn
else:
try:
self[name] = self._allocator(fn)
except TypeError:
if can_fail:
raise
def __getattr__(self, attr):
return self[attr]
from utils import OmegaError
class OmegaTypeError(OmegaError, TypeError):
pass
############################
# Dispatcher
############################
class Dispatcher(list):
all_dispatchers = {}
def __init__(self, name, description):
self.name = name
self.description = description
self.all_dispatchers[name] = self
def __call__(self, *inputs, **opts):
for candidate in self:
try:
return candidate(*inputs, **opts)
except OmegaTypeError:
continue
if opts:
s = " with options %s" % opts
else:
s = ""
raise OmegaTypeError("No candidate found for %s(%s) %s" \
% (self.name,
", ".join([input.__class__.__name__ for input in inputs]),
s))
def add_handler(self, x):
self.insert(0, x)
def fallback_handler(self, x):
self.append(x)
# Dispatchers for all python operators
Add = Dispatcher("Add", "x + y")
Subtract = Dispatcher("Subtract", "x - y")
Multiply = Dispatcher("Multiply", "x * y")
Divide = Dispatcher("Divide", "x / y")
FloorDivide = Dispatcher("FloorDivide", "x // y")
Modulo = Dispatcher("Modulo", "x % y")
Power = Dispatcher("Power", "x ** y")
Negate = Dispatcher("Negate", "-x")
Abs = Dispatcher("Abs", "abs(x)")
LeftShift = Dispatcher("LeftShift", "x << y")
RightShift = Dispatcher("RightShift", "x >> y")
Equals = Dispatcher("Equals", "x == y")
NotEquals = Dispatcher("NotEquals", "x != y")
Less = Dispatcher("Less", "x < y")
LessOrEqual = Dispatcher("LessOrEqual", "x <= y")
Greater = Dispatcher("Greater", "x > y")
GreaterOrEqual = Dispatcher("GreaterOrEqual", "x >= y")
Contains = Dispatcher("Contains", "x in y")
BinaryOr = Dispatcher("BinaryOr", "x | y")
BinaryAnd = Dispatcher("BinaryAnd", "x & y")
BinaryXor = Dispatcher("BinaryXor", "x ^ y")
BinaryInverse = Dispatcher("BinaryInverse", "~x")
# Dispatchers for special operations
Transpose = Dispatcher("Transpose", "x.T")
# Others
Log = Dispatcher("Log", 'log(x)')
Exp = Dispatcher("Exp", 'exp(x)')
Sin = Dispatcher("Sin", 'sin(x)')
Cos = Dispatcher("Cos", 'cos(x)')
Tan = Dispatcher("Tan", 'tan(x)')
############################
# PythonOperatorSupport
############################
class PythonOperatorSupport(object):
"""Support for built-in Python operators."""
# Common arithmetic operations
def __add__(self, x):
return Add(self, x)
def __radd__(self, x):
return Add(x, self)
def __sub__(self, x):
return Subtract(self, x)
def __rsub__(self, x):
return Subtract(x, self)
def __mul__(self, x):
return Multiply(self, x)
def __rmul__(self, x):
return Multiply(x, self)
def __div__(self, x):
return Divide(self, x)
def __rdiv__(self, x):
return Divide(x, self)
def __floordiv__(self, x):
return FloorDivide(self, x)
def __rfloordiv__(self, x):
return FloorDivide(x, self)
def __mod__(self, x):
return Modulo(self, x)
def __rmod__(self, x):
return Modulo(x, self)
def __pow__(self, x):
return Power(self, x)
def __rpow__(self, x):
return Power(x, self)
def __neg__(self):
return Negate(self)
def __abs__(self):
return Abs(self)
# Less common arithmetic operations
def __lshift__(self, x):
return LeftShift(self, x)
def __rlshift__(self, x):
return LeftShift(x, self)
def __rshift__(self, x):
return RightShift(self, x)
def __rrshift__(self, x):
return RightShift(x, self)
# Comparison operations
# def __eq__(self, x):
# return Equals(self, x)
# def __ne__(self, x):
# return NotEquals(self, x)
def __lt__(self, x):
return Less(self, x)
def __le__(self, x):
return LessOrEqual(self, x)
def __gt__(self, x):
return Greater(self, x)
def __ge__(self, x):
return GreaterOrEqual(self, x)
def __contains__(self, x):
return Contains(self, x)
# Binary operations
def __or__(self, x):
return BinaryOr(self, x)
def __ror__(self, x):
return BinaryOr(x, self)
def __and__(self, x):
return BinaryAnd(self, x)
def __rand__(self, x):
return BinaryAnd(x, self)
def __xor__(self, x):
return BinaryXor(self, x)
def __rxor__(self, x):
return BinaryXor(x, self)
def __invert__(self):
return BinaryInverse(self)
# Other operations
T = property(lambda self: Transpose(self))
norm = property(lambda self: Norm(self))
# Always nonzero
def __nonzero__(self):
return True
__all__ = globals().keys()
......@@ -110,9 +110,9 @@ class Env(graph.Graph):
### Public interface ###
# def add_output(self, output):
# self.outputs.add(output)
# self.__import_r__([output])
def add_output(self, output):
self.outputs.add(output)
self.__import_r__([output])
def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r."
......@@ -250,8 +250,6 @@ class Env(graph.Graph):
if r in self.outputs:
was_output = True
self.outputs[self.outputs.index(r)] = new_r
# self.outputs.remove(r)
# self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise
# an error.
......@@ -263,8 +261,6 @@ class Env(graph.Graph):
if was_output:
if not new_was_output:
self.outputs[self.outputs.index(new_r)] = r
# self.outputs.remove(new_r)
# self.outputs.add(r)
# Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r)
......
......@@ -118,7 +118,6 @@ class DestroyHandler(Listener, Constraint, Orderings):
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
####### self.__add_destroyer__(path + [op])
elif views:
if len(views) > 1:
......
from copy import copy
from op import Op
from result import is_result, ResultBase
from utils import ClsInit, Keyword, AbstractFunctionError
import opt
import env
import features
import ext
from python25 import all
__all__ = [ 'UNDEFINED',
'current_mode',
'set_mode',
'build_mode',
'eval_mode',
'build_eval_mode',
'pop_mode',
'DummyOp',
'DummyRemover',
'PythonOp',
'PythonOpt',
'make_static']
UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname):
f = getattr(cls, fname)
if hasattr(f, 'im_func'):
f = f.im_func
setattr(cls, fname, staticmethod(f))
def compute_from(nodes, history):
"""Recursively evaluate each node (in a quick & dirty way).
history (aka inputs) is a set of nodes that need not be [re]computed.
TODO: make this more correct by building a little graph and executing it.
The current implementation doesn't take into account any ordering
constraints imposed by destructors, for example.
"""
def compute_recursive(node):
if node and (node not in history):
if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner)
else: #node is op
if node.destroy_map():
raise ValueError('compute_from() does not work on nodes with destroy_maps')
for input in node.inputs:
compute_recursive(input)
node.perform()
history.add(node)
for n in nodes:
compute_recursive(n)
def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set())
def root_inputs(input):
"""Return the leaves of a search through consecutive view_map()s"""
owner = input.owner
if owner:
view_map = owner.view_map()
if input in view_map:
answer = []
for input2 in view_map[input]:
answer.extend(root_inputs(input2))
return answer
else:
return [input]
else:
return [input]
class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env):
self.env = env
self.bad = set()
def on_import(self, op):
for output, inputs in op.destroy_map().items():
for input in inputs:
for root_input in root_inputs(input):
if getattr(root_input, 'constant', False):
self.bad.add(op)
return
def on_prune(self, op):
if op in self.bad:
self.bad.remove(op)
def on_rewire(self, clients, r, new_r):
for op, i in clients:
self.on_prune(op)
self.on_import(op)
def validate(self):
if self.bad:
raise env.InconsistencyError("The following ops overwrite a constant value: %s" % self.bad)
else:
return True
class NewPythonOp(Op):
__env_require__ = DestroyHandler, ForbidConstantOverwrite
def view_map(self):
return {}
def destroy_map(self):
return {}
class PythonOp(NewPythonOp):
__metaclass__ = ClsInit
nout = 1
@staticmethod
def __clsinit__(cls, name, bases, dct):
# make impl a static method
cls.set_impl(cls.impl)
def __new__(cls, *inputs, **kwargs):
op = NewPythonOp.__new__(cls)
op.__init__(*inputs)
mode = kwargs.get('mode', None) or current_mode()
if mode == 'eval':
op.perform()
if op.nout == 1:
return op.out.data
else:
return [output.data for output in op.outputs]
elif mode == 'build_eval':
op.perform()
if op.nout == 1:
return op.out
else:
return op.outputs
def __init__(self, *inputs):
NewPythonOp.__init__(self, inputs, self.gen_outputs())
def __validate__(self):
return all([is_result(i) for i in self.inputs])
def gen_outputs(self):
raise AbstractFunctionError()
def check_input(self, input):
def input_is_up_to_date(input):
answer = True
for input in root_inputs(input):
answer &= input.up_to_date
return answer
if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self))
def perform(self):
def input_is_constant(input):
answer = False
for input in root_inputs(input):
answer |= input.constant
return answer
exc = set()
for output, inputs in self.destroy_map().items():
exc.update(inputs)
for input in inputs:
if self.input_is_constant(input):
raise ValueError("Input is constant: %s" % input)
for input in exc:
self.check_input(input)
input.up_to_date = False
for input in self.inputs:
if input not in exc:
self.check_input(input)
if 0:
#J- why is this try catch here? Leftover debug?
try:
results = self._impl()
except Exception, e:
print "Error in %s: %s" % (self, e)
raise
else:
results = self._impl()
if self.nout == 1:
self.out.data = results
else:
assert self.nout == len(results)
for result, output in zip(results, self.outputs):
output.data = result
def _perform(self):
results = self._impl()
if self.nout == 1:
self.out.data = results
else:
assert self.nout == len(results)
for result, output in zip(results, self.outputs):
output.data = result
def _perform_inplace(self):
results = self._impl()
if self.nout == 1:
self.out.set_value_inplace(results)
else:
assert self.nout == len(results)
for result, output in zip(results, self.outputs):
output.set_value_inplace(result)
def _impl(self):
return self.impl(*[input.data for input in self.inputs])
@classmethod
def set_impl(cls, impl):
make_static(cls, 'impl')
def impl(*args):
raise NotImplementedError("This op has no implementation.")
def __copy__(self):
"""
Copies the inputs list shallowly and copies all the outputs
because of the one owner per output restriction.
"""
# new_inputs = copy(op.inputs)
# # We copy the outputs because they are tied to a single Op.
# new_outputs = [copy(output) for output in op.outputs]
build_mode()
op = self.__class__(*self.inputs)
pop_mode()
# op._inputs = new_inputs
# op._outputs = new_outputs
# for i, output in enumerate(op.outputs):
# # We adjust _owner and _index manually since the copies
# # point to the previous op (self).
# output._owner = op
# output._index = i
if isinstance(op, (list, tuple)):
return op[0].owner
return op.owner
__mode__ = ['build_eval']
def current_mode():
return __mode__[-1]
def set_mode(mode):
__mode__.append(mode)
def build_mode():
set_mode('build')
def eval_mode():
set_mode('eval')
def build_eval_mode():
set_mode('build_eval')
def pop_mode():
if len(__mode__) == 1:
raise Exception("There's only one mode left on the stack.")
else:
__mode__.pop()
class PythonOpt(opt.Optimizer):
def __init__(self, opt):
self.opt = opt
def optimize(self, env):
build_mode()
self.opt.optimize(env)
pop_mode()
class DummyOp(NewPythonOp):
def __init__(self, input):
Op.__init__(self, [input], [ResultBase()])
def thunk(self):
return lambda:None
DummyRemover = opt.OpRemover(DummyOp)
if 0:
class RefreshableOp(NewPythonOp):
def _specs(self):
try:
return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError:
raise NotImplementedError("%s cannot infer the specs of its outputs" % self.__class__.__name__)
def specs(*inputs):
raise NotImplementedError
def refresh(self):
"""Update and allocate outputs if necessary"""
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
......@@ -7,10 +7,6 @@ __all__ = ['ModalConstructor',
'build',
'eval',
'build_eval',
# 'ModalWrapper',
# 'BuildMode',
# 'EvalMode',
# 'BuildEvalMode',
'make_constructors',
]
......@@ -31,28 +27,15 @@ class ModalConstructor:
if mode != modal_wrapper:
raise TypeError("Inconsistent modes.")
fn_args.append(arg)
# for arg in args:
# if isinstance(arg, ModalWrapper):
# if modal_wrapper is None:
# modal_wrapper = arg.__class__
# else:
# if not isinstance(arg, modal_wrapper):
# raise TypeError("Inconsistent modes.")
# fn_args.append(arg.r)
# else:
# fn_args.append(arg)
op = self.fn(*fn_args)
if modal_wrapper:
modal_wrapper(op)
# modal_wrapper.filter(op)
for output in op.outputs:
output.__mode__ = modal_wrapper
if len(op.outputs) == 1:
return op.outputs[0]
#return modal_wrapper(op.outputs[0])
else:
return op.outputs
#return [modal_wrapper(output) for output in op.outputs]
def add_modal_members(cls, *members):
......@@ -65,30 +48,6 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member))
# class ModalWrapper:
# def __init__(self, r):
# self.r = r
# def __as_result__(self):
# return self.r
# def __get_owner(self):
# return self.r.owner
# owner = property(__get_owner)
# @classmethod
# def filter(cls, op):
# raise AbstractFunctionError()
# members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ')
# members = []
# members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
# members += ["__r%s__" % x for x in members1]
# add_modal_members(ModalWrapper, *members)
def build_mode(op):
pass
......@@ -112,24 +71,6 @@ eval = mode_setter(eval_mode)
build_eval = mode_setter(build_eval_mode)
# class BuildMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# pass
# class EvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
# for output in op.outputs:
# output._role = None
# class BuildEvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
def _is_op(x):
try: return issubclass(x, Op)
except: return False
......
......@@ -142,22 +142,6 @@ class Op(object):
"""
raise AbstractFunctionError()
# def c_update(self):
# """
# Returns C code that allocates and/or updates the outputs
# (eg resizing, etc.) so they can be manipulated safely
# by c_code.
# You may use the variable names defined by c_var_names()
# """
# raise AbstractFunctionError()
# def c_update_cleanup(self):
# """
# Clean up things allocated by c_update().
# """
# raise AbstractFunctionError()
def c_code(self):
"""
Returns C code that does the computation associated to this
......
from env import Env
from utils import AbstractFunctionError
class Prog:
def __init__(self, inputs, outputs, optimizer, linker_class, features = []):
self.inputs = inputs
if isinstance(outputs, dict):
for name, output in outputs.items():
setattr(self, name, output)
self.outputs = outputs.values()
else:
self.outputs = outputs
self.optimizer = optimizer
self.env = Env(self.inputs, self.outputs, features, False)
self.env.add_feature(EquivTool)
self.linker = linker_class(self.env)
def build(self):
self.optimizer.optimize(self.env)
def equiv(self, r):
return self.env.equiv(r)
def __getitem__(self, r):
if isinstance(r, str):
return getattr(self, r)
else:
return self.equiv(r)
def __setitem__(self, r, value):
if isinstance(r, tuple):
for a, b in zip(r, value):
self.__setitem__(a, b)
else:
self[r].data = value
# import compile
import env
import link
from features import EquivTool
class Prog:
def __init__(self, inputs, outputs, optimizer, linker, features = []):
self.optimizer = optimizer
self.linker = linker
features = set(features)
features.add(EquivTool)
self.env = env.Env(inputs, outputs, features) #, False)
self.optimizer.optimize(self.env)
self.perform = self.linker(self.env)
self.outputs = outputs
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
def equiv(self, r):
return self.env.equiv(r)
def __getitem__(self, r):
return self.equiv(r)
def __setitem__(self, r, value):
if isinstance(r, tuple):
for a, b in zip(r, value):
self.__setitem__(a, b)
else:
self.equiv(r).set_value(value)
def __call__(self, *args):
self.perform()
for output in self.outputs:
output.set_value(self[output])
return self.outputs
# return [output for output in self.env.outputs]
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
import sys
if sys.version_info[:2] < (2,5):
def all(iterable):
for element in iterable:
if not element:
return False
return True
else:
# Only bother with this else clause and the __all__ line if you are putting
# this in a separate file.
import __builtin__
all = __builtin__.all
__all__ = ['all']
......@@ -160,33 +160,18 @@ class ResultBase(object):
"""
raise AbstractFunctionError()
# def c_extract(self):
# get_from_list = """
# //PyObject* py_%(name)s = PyList_GET_ITEM(%(name)s_storage, 0);
# //Py_XINCREF(py_%(name)s);
# """
# return get_from_list + self.c_data_extract()
def c_extract(self):
"""
# The code returned from this function must be templated using
# "%(name)s", representing the name that the caller wants to
# call this Result. The Python object self.data is in a
# variable called "py_%(name)s" and this code must declare a
# variable named "%(name)s" of type "%(type)s" where "%(type)s"
# will be replaced by the return value of
# self.c_type(). Additional variables and typedefs may not be
# produced. If the data is improper, set an appropriate error
# message and insert "%(fail)s".
The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to
call this Result. The Python object self.data is in a
variable called "py_%(name)s" and this code must set the
variables declared by c_declare to something representative
of py_%(name)s. If the data is improper, set an appropriate error
message and insert "%(fail)s".
"""
raise AbstractFunctionError()
# def c_cleanup(self):
# decref = """
# //Py_XDECREF(py_%(name)s);
# """
# return self.c_data_cleanup() + decref
def c_cleanup(self):
"""
This returns C code that should deallocate whatever
......@@ -197,13 +182,6 @@ class ResultBase(object):
"""
raise AbstractFunctionError()
# def c_sync(self):
# set_in_list = """
# //PyList_SET_ITEM(%(name)s_storage, 0, py_%(name)s);
# //Py_XDECREF(py_%(name)s);
# """
# return self.c_data_sync() + set_in_list
def c_sync(self):
"""
The code returned from this function must be templated using "%(name)s",
......@@ -287,38 +265,3 @@ class ResultBase(object):
def same_properties(self, other):
raise AbstractFunction()
#################
# NumpyR Compatibility
#
up_to_date = property(lambda self: True)
def refresh(self): pass
def set_owner(self, owner, idx):
self.role = (owner, idx)
def set_value(self, value):
self.data = value #may raise exception
# def c_data_extract(self):
# return """
# PyArrayObject* %%(name)s;
# if (py_%%(name)s == Py_None)
# %%(name)s = NULL;
# else
# %%(name)s = (PyArrayObject*)(py_%%(name)s);
# typedef %(dtype)s %%(name)s_dtype;
# """ % dict(dtype = self.dtype)
# def c_data_sync(self):
# return """
# if (!%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = Py_None;
# }
# else if ((void*)py_%(name)s != (void*)%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = (PyObject*)%(name)s;
# }
# """
#ifndef _OMEGA_H
#define _OMEGA_H
//#include whatever defines PyArrayObject
template<typename T>
struct TMat_t
{
T * __restrict__ d;/**< pointer to element (0,0) */
size_t M; /**< number of rows */
size_t N; /**< number of columns */
size_t m; /**< row stride */
size_t n; /**< column stride */
bool invalid;
/** null */
TMat_t(const PyArrayObject *o) :
d((double*) o->data),
M((o->nd==2) ? o->dimensions[0] : 0),
N((o->nd==2) ? o->dimensions[1] : 0),
m((o->nd==2) ? o->strides[0] / sizeof(double) : 0),
n((o->nd==2) ? o->strides[1] / sizeof(double) : 0),
invalid((o->nd !=2) || (o->descr->elsize != sizeof(T)))
{
}
/** unsafe element access */
const T & operator()(size_t i, size_t j) const
{
return d[ i * m + j*n];
}
/** unsafe element access */
T & operator()(size_t i, size_t j)
{
return d[ i * m + j*n];
}
/** safe element access */
const T & at(size_t i, size_t j) const
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
/** safe element access */
T & at(size_t i, size_t j)
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
};
#endif
差异被折叠。
import core
import gof
from numpy import random as r
# def rwrap(f):
# wrapped =
# def ret(self, *args):
class RandomState(gof.Op, gof.ext.IONames):
input_names = ['seed']
def __init__(self, seed):
inputs = [wrap(seed)]
outputs = [ResultValue()]
gof.Op.__init__(self, inputs, outputs)
def thunk(self):
def f():
self.out.storage = r.RandomState(self.seed.storage)
return f
class Random(object):
def __init__(seed):
self.state = core.wrap(seed)
......@@ -23,11 +23,30 @@ col = _broadcastable_pattern([0, 1])
class Tensor(ResultBase):
def __init__(self, dtype, broadcastable, data=None, name=None):
def __init__(self, dtype=None, broadcastable=None, data=None, name=None, constant=False):
if dtype is None or broadcastable is None:
if data is None:
raise TypeError("Provide non-None data to complete the dtype and broadcastable flags.")
dtype = data.dtype
if constant:
broadcastable = [1*(x == 1) for x in data.shape]
else:
broadcastable = [0] * len(data.shape)
self.broadcastable = broadcastable
self.dtype = str(dtype)
self.constant = constant
ResultBase.__init__(self, role = None, data = None, name = name)
def __get_constant(self):
return self._constant
def __set_constant(self, value):
if value:
self.indestructible = True
self._constant = value
constant = property(__get_constant, __set_constant)
def filter(self, data):
arr = numpy.asarray(data, dtype = self.dtype)
for b, s in zip(self.broadcastable, arr.shape):
......
差异被折叠。
from scipy.weave import c_spec, standard_array_spec
class omega_type_converter_extension:
# def provides(self):
# """
# Returns a list of (c_type, name, init_code) tuples that represent variables
# the type converter provides to the user's code.
# """
# return []
# def format_provide(self, x):
# return '%s %s = %s;\n' % x
def declaration_code(self, templatize = 0, inline = 0):
tvars = self.template_vars(inline=inline)
code = '%(py_var)s = %(var_lookup)s;\n' % tvars
code += '%(c_type)s %(name)s = %(var_convert)s;\n' % tvars
return code
def struct_init_code(self):
return "Py_INCREF(py_%s);" % self.name
def struct_cleanup_code(self):
return "Py_DECREF(py_%s);" % self.name
def struct_members_code(self):
tvars = self.template_vars()
res = "PyObject* py_%s;\n" % self.name
res += "%(c_type)s %(name)s;\n" % tvars
return res
def struct_import_code(self):
res = "__STRUCT_P->py_%s = py_%s;\n" % (self.name, self.name)
res += "__STRUCT_P->%s = %s;\n" % (self.name, self.name)
return res
def struct_support_code(self):
return ""
def struct_typedefs(self):
return ""
class int_converter(omega_type_converter_extension, c_spec.int_converter):
pass
class float_converter(omega_type_converter_extension, c_spec.float_converter):
pass
class complex_converter(omega_type_converter_extension, c_spec.complex_converter):
pass
class unicode_converter(omega_type_converter_extension, c_spec.unicode_converter):
pass
# def provides(self):
# tvars = self.template_vars()
# return omega_type_converter_extension.provides() + [('int', 'N%(name)s' % tvars, 'PyUnicode_GET_SIZE(%(py_var)s)' % tvars)]
class string_converter(omega_type_converter_extension, c_spec.string_converter):
pass
class list_converter(omega_type_converter_extension, c_spec.list_converter):
pass
class dict_converter(omega_type_converter_extension, c_spec.dict_converter):
pass
class tuple_converter(omega_type_converter_extension, c_spec.tuple_converter):
pass
class file_converter(omega_type_converter_extension, c_spec.file_converter):
pass
class instance_converter(omega_type_converter_extension, c_spec.instance_converter):
pass
class array_converter(omega_type_converter_extension, standard_array_spec.array_converter):
# def provides(self):
# tvars = self.template_vars()
# ret = []
# ret.append((tvars['c_type'], tvars['array_name'], tvars['var_convert']))
# ret.append(('npy_intp*', 'N%(name)s' % tvars, '%(array_name)s->dimensions' % tvars))
# ret.append(('npy_intp*', 'S%(name)s' % tvars, '%(array_name)s->strides' % tvars))
# ret.append(('int', 'D%(name)s' % tvars, '%(array_name)s->nd' % tvars))
# ret.append(('%(num_type)s*' % tvars, '%(name)s' % tvars, '(%(num_type)s*) %(array_name)s->data' % tvars))
# return ret
# def declaration_code(self, templatize = 0, inline = 0):
# tvars = self.template_vars(inline=inline)
# tvars['cap_name'] = self.name.upper()
# prov = self.provides()
# code = '%(py_var)s = %(var_lookup)s;\n' % tvars
# code += "\n".join(self.format_provide(export) for export in prov[:1])
# code += '\nconversion_numpy_check_type(%(array_name)s,%(num_typecode)s,"%(name)s");\n' % tvars
# code += "\n".join(self.format_provide(export) for export in prov[1:])
# return code
# def struct_support_code(self, templatize = 0, inline = 0):
# tvars = self.template_vars(inline=inline)
# cap_name = self.name.upper()
# tvars['cap_name'] = cap_name
# code = 'inline %(num_type)s& %(cap_name)s1(int i) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0])));}\n' \
# 'inline %(num_type)s& %(cap_name)s2(int i, int j) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1])));}\n' \
# 'inline %(num_type)s& %(cap_name)s3(int i, int j, int k) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1] + (k)*S%(name)s[2])));}\n' \
# 'inline %(num_type)s& %(cap_name)s4(int i, int j, int k, int l) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1] + (k)*S%(name)s[2] + (l)*S%(name)s[3])));}\n'
# return code % tvars
def struct_typedefs(self):
tvars = self.template_vars()
return "typedef %(num_type)s %(name)s_dtype;\n" % tvars
# return "\n".join(["typedef %s %s_type;" % (c_type, name)])
# def struct_template_types(self):
# tvars = self.template_vars()
# return [("typename %s_type" % name, c_type) for c_type, name, init in self.provides()] + [("typename %s_dtype" % self.name, tvars['num_type'])]
default = [array_converter(),
int_converter(),
float_converter(),
complex_converter(),
unicode_converter(),
string_converter(),
list_converter(),
dict_converter(),
tuple_converter(),
file_converter(),
instance_converter()]
from core import Numpy2, omega_op
def input(x):
#static member initialization
if not hasattr(input, 'float_dtype'):
input.float_dtype = 'float64'
input.int_dtype = 'int64'
input.NN = Numpy2
if isinstance(x, numpy.ndarray):
#return NumpyR(x)
return input.NN(data=x)
elif isinstance(x, int):
z = numpy.zeros((), dtype = input.int_dtype)
z += x
return input.NN(data=z)
elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype)
z += x
return input.NN(data=z)
elif is_result(x):
raise TypeError("%s is already a result." % x)
else:
return ResultBase(data=x)
def wrap(x):
if isinstance(x, Numpy2):
return x
#elif isinstance(x, NumpyR):
#return x
elif is_result(x):
return x
elif isinstance(x, omega_op):
return x.out
else:
return literal(x)
def literal(x):
"""Return a ResultValue instance wrapping a literal."""
def _hashable(x):
try:
x in {}
return True
except TypeError: # x is unhashable
return False
#static member initialization
if not hasattr(literal, 'hdb'):
literal.hdb = {}
literal.udb = {}
if _hashable(x):
db = literal.hdb
key = (type(x),x)
else:
db = literal.udb
key = (id(x),)
if key in db:
return db[key]
else:
rval = input(x)
rval.constant = True
db[key] = rval
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论