moving from specs to refresh, passing tests in compile.py

上级 b747a0cb
......@@ -18,6 +18,7 @@ def experimental_linker(env, target = None):
py_ops = set()
thunks = []
computed_results = []
for op in order:
try:
......@@ -34,6 +35,7 @@ def experimental_linker(env, target = None):
result = op._perform
py_ops.add(op)
thunks.append((result, op._perform_inplace))
computed_results.extend(op.outputs)
def ret():
for thunk, fallback in thunks:
......@@ -41,6 +43,8 @@ def experimental_linker(env, target = None):
thunk()
except NotImplementedError:
fallback()
for r in computed_results:
r.state = gof.result.Computed
if not target:
return ret
......@@ -48,38 +52,6 @@ def experimental_linker(env, target = None):
raise NotImplementedError("Cannot write thunk representation to a file.")
# def experimental_linker(env, target = None):
# def fetch(op):
# try:
# factory = op.c_thunk_factory()
# # print "yea %s" % op
# thunk = factory()
# return lambda: cutils.run_cthunk(thunk)
# except NotImplementedError:
# # print "nope %s" % op
# return op._perform
# order = env.toposort()
# for op in order:
# op.refresh()
# # for op in order:
# # print op
# # print 'ispecs: ', [input.spec for input in op.inputs]
# # print 'ospecs: ', [output.spec for output in op.outputs]
# thunks = [fetch(op) for op in order]
# def ret():
# # print "=================================================="
# # for thunk, op in zip(thunks, order):
# # print op
# # print 'in: ', [id(input.data) for input in op.inputs]
# # print 'out:', [id(output.data) for output in op.outputs]
# # thunk()
# for thunk in thunks:
# thunk()
# if not target:
# return ret
# else:
# raise NotImplementedError("Cannot write thunk representation to a file.")
class profile_linker:
def __init__(self, env):
self.order = env.toposort()
......@@ -201,10 +173,9 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single(unittest.TestCase):
class _test_single_build_mode(unittest.TestCase):
def setUp(self):
core.build_eval_mode()
core.build_mode()
numpy.random.seed(44)
def tearDown(self):
core.pop_mode()
......@@ -215,27 +186,44 @@ class _test_single(unittest.TestCase):
c = core.add(a,b)
self.failUnless(c.data is None)
self.failUnless(c.state is Empty)
self.failUnless(c.state is gof.result.Empty)
p = single(c)
self.failUnless(c.data is not None)
self.failUnless(c.state is gof.result.Allocated)
self.failUnless(not core._approx_eq(c, a.data + b.data))
p()
self.failUnless(c.state is gof.result.Computed)
self.failUnless(core._approx_eq(c, a.data + b.data))
new_a = numpy.random.rand(2,2)
new_b = numpy.random.rand(2,2)
a.data = new_a
b.data = new_b
p = single(c)
a.data[:] = new_a
b.data[:] = new_b
p()
self.failUnless(core._approx_eq(c, new_a + new_b))
def test_get_element(self):
core.build_eval_mode()
a_data = numpy.random.rand(2,2)
a = core.Numpy2(data=a_data)
a_i = a[0,0]
pos = core.input((0,0))
a_i = core.get_slice(a, pos)
p = single(a_i)
#p()
#print 'aaaa', a_i.owner.out, a_i.owner, a_i.data, pos.data
#print 'pre p()'
for i in 0,1:
for j in 0,1:
pos.data = (i,j)
p()
#print 'asdf', i,j,a_i.data
#print a_i.owner.inputs[1].data
#a_i.owner.inputs[1].data = [i,j]
self.failUnless(a_data[i,j] == a_i.data)
core.pop_mode()
if __name__ == '__main__':
......
差异被折叠。
......@@ -350,7 +350,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
class NewPythonOp(Op):
__env_require__ = DestroyHandler
__env_require__ = DestroyHandler, ForbidConstantOverwrite
def view_map(self):
return {}
......@@ -358,7 +358,6 @@ class NewPythonOp(Op):
def destroy_map(self):
return {}
class PythonOp(NewPythonOp):
__metaclass__ = ClsInit
......@@ -369,10 +368,9 @@ class PythonOp(NewPythonOp):
def __clsinit__(cls, name, bases, dct):
# make impl a static method
cls.set_impl(cls.impl)
make_static(cls, 'specs')
def __new__(cls, *inputs, **kwargs):
op = Op.__new__(cls)
op = NewPythonOp.__new__(cls)
op.__init__(*inputs)
mode = kwargs.get('mode', None) or current_mode()
if mode == 'eval':
......@@ -471,40 +469,6 @@ class PythonOp(NewPythonOp):
def impl(*args):
raise NotImplementedError("This op has no implementation.")
def _specs(self):
try:
return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError:
raise NotImplementedError("%s cannot infer the specs of its outputs" % self.__class__.__name__)
def specs(*inputs):
raise NotImplementedError
def refresh(self, except_list = []):
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
__env_require__ = ForbidConstantOverwrite
def __copy__(self):
"""
......@@ -577,3 +541,41 @@ class DummyOp(NewPythonOp):
DummyRemover = opt.OpRemover(DummyOp)
if 0:
class RefreshableOp(NewPythonOp):
def _specs(self):
try:
return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError:
raise NotImplementedError("%s cannot infer the specs of its outputs" % self.__class__.__name__)
def specs(*inputs):
raise NotImplementedError
def refresh(self):
"""Update and allocate outputs if necessary"""
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
......@@ -93,11 +93,11 @@ class ResultBase(object):
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
class BrokenLinkError(Exception):
"""Exception thrown when an owner is a BrokenLink"""
class BrokenLinkError(Exception):
"""The owner is a BrokenLink"""
class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called"""
class StateError(Exception):
"""The state of the Result is a problem"""
__slots__ = ['_role', 'constant', '_data', 'state']
......@@ -111,7 +111,7 @@ class ResultBase(object):
else:
try:
self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction:
except AbstractFunctionError:
self._data[0] = data
self.state = Computed
......@@ -175,10 +175,13 @@ class ResultBase(object):
self._data[0] = None
self.state = Empty
return
if data is self or data is self._data[0]: return
try:
self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour
except AbstractFunctionError: #use default behaviour
self._data[0] = data
if isinstance(data, ResultBase):
raise Exception()
self.state = Computed
data = property(__get_data, __set_data,
......@@ -193,14 +196,19 @@ class ResultBase(object):
the contents of self._data remain sensible.
"""
raise ResultBase.AbstractFunction()
raise AbstractFunctionError()
#
# alloc
#
def alloc(self):
"""Create self.data from data_alloc, and set state to Allocated"""
"""Create self.data from data_alloc, and set state to Allocated
Graph routines like the linker will ask Ops to allocate outputs. The
Ops, in turn, usually call this function. Results that are involved in
destroy maps and view maps are exceptions to the usual case.
"""
self.data = self.data_alloc() #might raise exception
self.state = Allocated
......@@ -211,7 +219,7 @@ class ResultBase(object):
implementation will be used in alloc() to produce a data object.
"""
raise ResultBase.AbstractFunction()
raise AbstractFunctionError()
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论