moving from specs to refresh, passing tests in compile.py

上级 b747a0cb
...@@ -18,6 +18,7 @@ def experimental_linker(env, target = None): ...@@ -18,6 +18,7 @@ def experimental_linker(env, target = None):
py_ops = set() py_ops = set()
thunks = [] thunks = []
computed_results = []
for op in order: for op in order:
try: try:
...@@ -34,6 +35,7 @@ def experimental_linker(env, target = None): ...@@ -34,6 +35,7 @@ def experimental_linker(env, target = None):
result = op._perform result = op._perform
py_ops.add(op) py_ops.add(op)
thunks.append((result, op._perform_inplace)) thunks.append((result, op._perform_inplace))
computed_results.extend(op.outputs)
def ret(): def ret():
for thunk, fallback in thunks: for thunk, fallback in thunks:
...@@ -41,6 +43,8 @@ def experimental_linker(env, target = None): ...@@ -41,6 +43,8 @@ def experimental_linker(env, target = None):
thunk() thunk()
except NotImplementedError: except NotImplementedError:
fallback() fallback()
for r in computed_results:
r.state = gof.result.Computed
if not target: if not target:
return ret return ret
...@@ -48,38 +52,6 @@ def experimental_linker(env, target = None): ...@@ -48,38 +52,6 @@ def experimental_linker(env, target = None):
raise NotImplementedError("Cannot write thunk representation to a file.") raise NotImplementedError("Cannot write thunk representation to a file.")
# def experimental_linker(env, target = None):
# def fetch(op):
# try:
# factory = op.c_thunk_factory()
# # print "yea %s" % op
# thunk = factory()
# return lambda: cutils.run_cthunk(thunk)
# except NotImplementedError:
# # print "nope %s" % op
# return op._perform
# order = env.toposort()
# for op in order:
# op.refresh()
# # for op in order:
# # print op
# # print 'ispecs: ', [input.spec for input in op.inputs]
# # print 'ospecs: ', [output.spec for output in op.outputs]
# thunks = [fetch(op) for op in order]
# def ret():
# # print "=================================================="
# # for thunk, op in zip(thunks, order):
# # print op
# # print 'in: ', [id(input.data) for input in op.inputs]
# # print 'out:', [id(output.data) for output in op.outputs]
# # thunk()
# for thunk in thunks:
# thunk()
# if not target:
# return ret
# else:
# raise NotImplementedError("Cannot write thunk representation to a file.")
class profile_linker: class profile_linker:
def __init__(self, env): def __init__(self, env):
self.order = env.toposort() self.order = env.toposort()
...@@ -201,10 +173,9 @@ def to_func(inputs, outputs): ...@@ -201,10 +173,9 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs): def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs) return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single_build_mode(unittest.TestCase):
class _test_single(unittest.TestCase):
def setUp(self): def setUp(self):
core.build_eval_mode() core.build_mode()
numpy.random.seed(44) numpy.random.seed(44)
def tearDown(self): def tearDown(self):
core.pop_mode() core.pop_mode()
...@@ -215,27 +186,44 @@ class _test_single(unittest.TestCase): ...@@ -215,27 +186,44 @@ class _test_single(unittest.TestCase):
c = core.add(a,b) c = core.add(a,b)
self.failUnless(c.data is None) self.failUnless(c.data is None)
self.failUnless(c.state is Empty) self.failUnless(c.state is gof.result.Empty)
p = single(c)
self.failUnless(c.data is not None)
self.failUnless(c.state is gof.result.Allocated)
self.failUnless(not core._approx_eq(c, a.data + b.data))
p()
self.failUnless(c.state is gof.result.Computed)
self.failUnless(core._approx_eq(c, a.data + b.data))
new_a = numpy.random.rand(2,2) new_a = numpy.random.rand(2,2)
new_b = numpy.random.rand(2,2) new_b = numpy.random.rand(2,2)
a.data = new_a a.data[:] = new_a
b.data = new_b b.data[:] = new_b
p = single(c)
p() p()
self.failUnless(core._approx_eq(c, new_a + new_b)) self.failUnless(core._approx_eq(c, new_a + new_b))
def test_get_element(self): def test_get_element(self):
core.build_eval_mode()
a_data = numpy.random.rand(2,2) a_data = numpy.random.rand(2,2)
a = core.Numpy2(data=a_data) a = core.Numpy2(data=a_data)
a_i = a[0,0] pos = core.input((0,0))
a_i = core.get_slice(a, pos)
p = single(a_i) p = single(a_i)
#p()
#print 'aaaa', a_i.owner.out, a_i.owner, a_i.data, pos.data
#print 'pre p()'
for i in 0,1: for i in 0,1:
for j in 0,1: for j in 0,1:
pos.data = (i,j)
p() p()
#print 'asdf', i,j,a_i.data
#print a_i.owner.inputs[1].data
#a_i.owner.inputs[1].data = [i,j]
self.failUnless(a_data[i,j] == a_i.data) self.failUnless(a_data[i,j] == a_i.data)
core.pop_mode()
if __name__ == '__main__': if __name__ == '__main__':
......
差异被折叠。
...@@ -350,7 +350,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings) ...@@ -350,7 +350,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
class NewPythonOp(Op): class NewPythonOp(Op):
__env_require__ = DestroyHandler __env_require__ = DestroyHandler, ForbidConstantOverwrite
def view_map(self): def view_map(self):
return {} return {}
...@@ -358,7 +358,6 @@ class NewPythonOp(Op): ...@@ -358,7 +358,6 @@ class NewPythonOp(Op):
def destroy_map(self): def destroy_map(self):
return {} return {}
class PythonOp(NewPythonOp): class PythonOp(NewPythonOp):
__metaclass__ = ClsInit __metaclass__ = ClsInit
...@@ -369,10 +368,9 @@ class PythonOp(NewPythonOp): ...@@ -369,10 +368,9 @@ class PythonOp(NewPythonOp):
def __clsinit__(cls, name, bases, dct): def __clsinit__(cls, name, bases, dct):
# make impl a static method # make impl a static method
cls.set_impl(cls.impl) cls.set_impl(cls.impl)
make_static(cls, 'specs')
def __new__(cls, *inputs, **kwargs): def __new__(cls, *inputs, **kwargs):
op = Op.__new__(cls) op = NewPythonOp.__new__(cls)
op.__init__(*inputs) op.__init__(*inputs)
mode = kwargs.get('mode', None) or current_mode() mode = kwargs.get('mode', None) or current_mode()
if mode == 'eval': if mode == 'eval':
...@@ -471,40 +469,6 @@ class PythonOp(NewPythonOp): ...@@ -471,40 +469,6 @@ class PythonOp(NewPythonOp):
def impl(*args): def impl(*args):
raise NotImplementedError("This op has no implementation.") raise NotImplementedError("This op has no implementation.")
def _specs(self):
try:
return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError:
raise NotImplementedError("%s cannot infer the specs of its outputs" % self.__class__.__name__)
def specs(*inputs):
raise NotImplementedError
def refresh(self, except_list = []):
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
__env_require__ = ForbidConstantOverwrite
def __copy__(self): def __copy__(self):
""" """
...@@ -577,3 +541,41 @@ class DummyOp(NewPythonOp): ...@@ -577,3 +541,41 @@ class DummyOp(NewPythonOp):
DummyRemover = opt.OpRemover(DummyOp) DummyRemover = opt.OpRemover(DummyOp)
if 0:
class RefreshableOp(NewPythonOp):
def _specs(self):
try:
return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError:
raise NotImplementedError("%s cannot infer the specs of its outputs" % self.__class__.__name__)
def specs(*inputs):
raise NotImplementedError
def refresh(self):
"""Update and allocate outputs if necessary"""
for input in self.inputs:
input.refresh()
change = self._propagate_specs()
if change:
self.alloc(except_list)
return change
def _propagate_specs(self):
specs = self._specs()
if self.nout == 1:
specs = [specs]
change = False
for output, spec in zip(self.outputs, specs):
if output.spec != spec:
output.spec = spec
change = True
return change
def alloc(self, except_list = []):
for output in self.outputs:
if output not in except_list:
output.alloc()
...@@ -93,11 +93,11 @@ class ResultBase(object): ...@@ -93,11 +93,11 @@ class ResultBase(object):
def __init__(self, role): self.old_role = role def __init__(self, role): self.old_role = role
def __nonzero__(self): return False def __nonzero__(self): return False
class BrokenLinkError(Exception): class BrokenLinkError(Exception):
"""Exception thrown when an owner is a BrokenLink""" """The owner is a BrokenLink"""
class AbstractFunction(Exception): class StateError(Exception):
"""Exception thrown when an abstract function is called""" """The state of the Result is a problem"""
__slots__ = ['_role', 'constant', '_data', 'state'] __slots__ = ['_role', 'constant', '_data', 'state']
...@@ -111,7 +111,7 @@ class ResultBase(object): ...@@ -111,7 +111,7 @@ class ResultBase(object):
else: else:
try: try:
self._data[0] = self.data_filter(data) self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: except AbstractFunctionError:
self._data[0] = data self._data[0] = data
self.state = Computed self.state = Computed
...@@ -175,10 +175,13 @@ class ResultBase(object): ...@@ -175,10 +175,13 @@ class ResultBase(object):
self._data[0] = None self._data[0] = None
self.state = Empty self.state = Empty
return return
if data is self or data is self._data[0]: return
try: try:
self._data[0] = self.data_filter(data) self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour except AbstractFunctionError: #use default behaviour
self._data[0] = data self._data[0] = data
if isinstance(data, ResultBase):
raise Exception()
self.state = Computed self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
...@@ -193,14 +196,19 @@ class ResultBase(object): ...@@ -193,14 +196,19 @@ class ResultBase(object):
the contents of self._data remain sensible. the contents of self._data remain sensible.
""" """
raise ResultBase.AbstractFunction() raise AbstractFunctionError()
# #
# alloc # alloc
# #
def alloc(self): def alloc(self):
"""Create self.data from data_alloc, and set state to Allocated""" """Create self.data from data_alloc, and set state to Allocated
Graph routines like the linker will ask Ops to allocate outputs. The
Ops, in turn, usually call this function. Results that are involved in
destroy maps and view maps are exceptions to the usual case.
"""
self.data = self.data_alloc() #might raise exception self.data = self.data_alloc() #might raise exception
self.state = Allocated self.state = Allocated
...@@ -211,7 +219,7 @@ class ResultBase(object): ...@@ -211,7 +219,7 @@ class ResultBase(object):
implementation will be used in alloc() to produce a data object. implementation will be used in alloc() to produce a data object.
""" """
raise ResultBase.AbstractFunction() raise AbstractFunctionError()
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论