cthunks work!!

上级 71ada6d8
import time import time
import gof import gof
import cutils
import core import core
import opt import opt
from copy import copy from copy import copy
def experimental_linker(env, target = None): def experimental_linker(env, target = None):
def fetch(op): def fetch(op):
try: try:
thunk = op.c_thunk() thunk = op.c_thunk_creator()
print "yea %s" % op # print "yea %s" % op
return lambda: cutils.run_cthunk(thunk) return lambda: cutils.run_cthunk(thunk())
except NotImplementedError: except NotImplementedError:
print "nope %s" % op # print "nope %s" % op
return op._perform return op._perform
order = env.toposort() order = env.toposort()
for op in order:
op.refresh()
# for op in order:
# print op
# print 'ispecs: ', [input.spec for input in op.inputs]
# print 'ospecs: ', [output.spec for output in op.outputs]
thunks = [fetch(op) for op in order] thunks = [fetch(op) for op in order]
def ret(): def ret():
for thunk in thunks: for thunk, op in zip(thunks, order):
# print op
# print 'in: ', [id(input.data) for input in op.inputs]
# print 'out:', [id(output.data) for output in op.outputs]
thunk() thunk()
# for thunk in thunks:
# thunk()
if not target: if not target:
return ret return ret
else: else:
...@@ -102,7 +111,7 @@ class prog(gof.Prog): ...@@ -102,7 +111,7 @@ class prog(gof.Prog):
TODO: think about whether orphan computation should be in this function, TODO: think about whether orphan computation should be in this function,
or in self.__call__() or in self.__call__()
""" """
# linker = experimental_linker linker = experimental_linker
new_outputs = gof.mark_outputs_as_destroyed(outputs) new_outputs = gof.mark_outputs_as_destroyed(outputs)
gof.Prog.__init__(self, gof.Prog.__init__(self,
inputs, inputs,
......
差异被折叠。
...@@ -20,6 +20,8 @@ except ImportError: ...@@ -20,6 +20,8 @@ except ImportError:
} }
""" """
cthunk = object() cthunk = object()
mod = weave.ext_tools.ext_module('cutils_ext') mod = weave.ext_tools.ext_module('cutils_ext')
mod.add_function(weave.ext_tools.ext_function('run_cthunk', single_runner, ['cthunk'])) mod.add_function(weave.ext_tools.ext_function('run_cthunk', single_runner, ['cthunk']))
......
...@@ -22,6 +22,7 @@ __all__ = ['UNCOMPUTED', ...@@ -22,6 +22,7 @@ __all__ = ['UNCOMPUTED',
'PythonOp', 'PythonOp',
'PythonOpt', 'PythonOpt',
'COp', 'COp',
'make_static',
'DualImplOp'] 'DualImplOp']
...@@ -29,6 +30,13 @@ UNCOMPUTED = Keyword("UNCOMPUTED", False) ...@@ -29,6 +30,13 @@ UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False) UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname):
f = getattr(cls, fname)
if hasattr(f, 'im_func'):
f = f.im_func
setattr(cls, fname, staticmethod(f))
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env): def __init__(self, env):
...@@ -75,13 +83,14 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -75,13 +83,14 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class PythonR(Result): class PythonR(Result):
__slots__ = ['data', 'constant', 'up_to_date'] __slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = None, constant = False): def __init__(self, x = None, constant = False):
self.constant = False self.constant = False
self.set_value(x) self.set_value(x)
self.constant = constant self.constant = constant
self.up_to_date = True self.up_to_date = True
self.spec = None
def set_value(self, value): def set_value(self, value):
if self.constant: if self.constant:
...@@ -93,6 +102,7 @@ class PythonR(Result): ...@@ -93,6 +102,7 @@ class PythonR(Result):
else: else:
self.data = value self.data = value
self.up_to_date = True self.up_to_date = True
self.refresh()
def __str__(self): def __str__(self):
return str(self.data) return str(self.data)
...@@ -100,10 +110,16 @@ class PythonR(Result): ...@@ -100,10 +110,16 @@ class PythonR(Result):
def __repr__(self): def __repr__(self):
return repr(self.data) return repr(self.data)
def refresh(self):
self.spec = id(self.data)
def alloc(self):
raise TypeError("Cannot allocate following this specification.")
def perform(self): def perform(self):
if self.owner: if self.owner:
self.owner.perform() self.owner.perform()
def compute(self): def compute(self):
if self.owner: if self.owner:
self.owner.compute() self.owner.compute()
...@@ -112,22 +128,19 @@ class PythonR(Result): ...@@ -112,22 +128,19 @@ class PythonR(Result):
class PythonOp(Op): class PythonOp(Op):
__metaclass__ = ClsInit __metaclass__ = ClsInit
__mode__ = ['build_eval']
nout = 1 nout = 1
@staticmethod @staticmethod
def __clsinit__(cls, name, bases, dct): def __clsinit__(cls, name, bases, dct):
# make impl a static method # make impl a static method
impl = cls.impl cls.set_impl(cls.impl)
if hasattr(cls.impl, 'im_func'): make_static(cls, 'specs')
impl = impl.im_func
cls.impl = staticmethod(impl)
def __new__(cls, *inputs, **kwargs): def __new__(cls, *inputs, **kwargs):
op = Op.__new__(cls) op = Op.__new__(cls)
op.__init__(*inputs) op.__init__(*inputs)
mode = kwargs.get('mode', None) or cls.current_mode() mode = kwargs.get('mode', None) or current_mode()
if mode == 'eval': if mode == 'eval':
op.perform() op.perform()
if op.nout == 1: if op.nout == 1:
...@@ -147,33 +160,6 @@ class PythonOp(Op): ...@@ -147,33 +160,6 @@ class PythonOp(Op):
def __validate__(self): def __validate__(self):
for input in self.inputs: for input in self.inputs:
assert isinstance(input, PythonR) assert isinstance(input, PythonR)
@classmethod
def current_mode(cls):
return cls.__mode__[-1]
@classmethod
def set_mode(cls, mode):
cls.__mode__.append(mode)
@classmethod
def build_mode(cls):
cls.set_mode('build')
@classmethod
def eval_mode(cls):
cls.set_mode('eval')
@classmethod
def build_eval_mode(cls):
cls.set_mode('build_eval')
@classmethod
def pop_mode(cls):
if len(cls.__mode__) == 1:
raise Exception("There's only one mode left on the stack.")
else:
cls.__mode__.pop()
def gen_outputs(self): def gen_outputs(self):
return [PythonR() for i in xrange(self.nout)] return [PythonR() for i in xrange(self.nout)]
...@@ -269,10 +255,48 @@ class PythonOp(Op): ...@@ -269,10 +255,48 @@ class PythonOp(Op):
def _impl(self): def _impl(self):
return self.impl(*[input.data for input in self.inputs]) return self.impl(*[input.data for input in self.inputs])
@classmethod
def set_impl(cls, impl):
make_static(cls, 'impl')
# impl = cls.impl
# if hasattr(cls.impl, 'im_func'):
# impl = impl.im_func
# cls.impl = staticmethod(impl)
def impl(*args): def impl(*args):
raise NotImplementedError("This op has no implementation.") raise NotImplementedError("This op has no implementation.")
def _specs(self):
return self.specs(*[input.spec for input in self.inputs])
def specs(*inputs):
raise NotImplementedError("This op cannot infer the specs of its outputs.")
def refresh(self, except_list = []):
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
__require__ = ForbidConstantOverwrite __require__ = ForbidConstantOverwrite
def __copy__(self): def __copy__(self):
...@@ -297,16 +321,30 @@ class PythonOp(Op): ...@@ -297,16 +321,30 @@ class PythonOp(Op):
return op[0].owner return op[0].owner
return op.owner return op.owner
__mode__ = ['build_eval']
def current_mode():
return __mode__[-1]
def set_mode(mode):
__mode__.append(mode)
def build_mode():
set_mode('build')
def eval_mode():
set_mode('eval')
def build_eval_mode():
set_mode('build_eval')
def pop_mode():
if len(__mode__) == 1:
raise Exception("There's only one mode left on the stack.")
else:
__mode__.pop()
current_mode = PythonOp.current_mode
set_mode = PythonOp.set_mode
build_mode = PythonOp.build_mode
eval_mode = PythonOp.eval_mode
build_eval_mode = PythonOp.build_eval_mode
pop_mode = PythonOp.pop_mode
class PythonOpt(opt.Optimizer): class PythonOpt(opt.Optimizer):
...@@ -315,9 +353,9 @@ class PythonOpt(opt.Optimizer): ...@@ -315,9 +353,9 @@ class PythonOpt(opt.Optimizer):
self.opt = opt self.opt = opt
def optimize(self, env): def optimize(self, env):
PythonOp.build_mode() build_mode()
self.opt.optimize(env) self.opt.optimize(env)
PythonOp.pop_mode() pop_mode()
......
...@@ -153,7 +153,7 @@ class Op(object): ...@@ -153,7 +153,7 @@ class Op(object):
self.set_output(i, previous, False) self.set_output(i, previous, False)
def refresh(self, allow_changes = False): def repair(self, allow_changes = False):
""" """
This function attempts to repair all inputs that are broken This function attempts to repair all inputs that are broken
links by calling set_input on the new Result that replaced links by calling set_input on the new Result that replaced
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论