producers passing tests... working in general?

上级 d4e705ea
...@@ -84,8 +84,6 @@ def _compile_dir(): ...@@ -84,8 +84,6 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class Allocated:
"""Memory has been allocated, but contents are not the owner's output."""
class Numpy2(ResultBase): class Numpy2(ResultBase):
"""Result storing a numpy ndarray""" """Result storing a numpy ndarray"""
__slots__ = ['_dtype', '_shape', ] __slots__ = ['_dtype', '_shape', ]
...@@ -122,8 +120,7 @@ class Numpy2(ResultBase): ...@@ -122,8 +120,7 @@ class Numpy2(ResultBase):
__array_struct__ = property(lambda self: self._data.__array_struct__ ) __array_struct__ = property(lambda self: self._data.__array_struct__ )
def data_alloc(self): def data_alloc(self):
self.data = numpy.ndarray(self.shape, self.dtype) return numpy.ndarray(self.shape, self.dtype)
self.state = Allocated
# self._dtype is used when self._data hasn't been set yet # self._dtype is used when self._data hasn't been set yet
def __dtype_get(self): def __dtype_get(self):
...@@ -211,7 +208,7 @@ class _test_Numpy2(unittest.TestCase): ...@@ -211,7 +208,7 @@ class _test_Numpy2(unittest.TestCase):
self.failUnless(r.data is None) self.failUnless(r.data is None)
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32') self.failUnless(str(r.dtype) == 'int32')
r.data_alloc() r.alloc()
self.failUnless(isinstance(r.data, numpy.ndarray)) self.failUnless(isinstance(r.data, numpy.ndarray))
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32') self.failUnless(str(r.dtype) == 'int32')
...@@ -443,6 +440,8 @@ class omega_op(gof.PythonOp): ...@@ -443,6 +440,8 @@ class omega_op(gof.PythonOp):
def gen_outputs(self): def gen_outputs(self):
return [Numpy2() for i in xrange(self.nout)] return [Numpy2() for i in xrange(self.nout)]
#TODO: use the version of this code that is in grad.py
# requires: eliminating module dependency cycles
def update_gradient(self, grad_d): def update_gradient(self, grad_d):
"""Call self.grad() and add the result to grad_d """Call self.grad() and add the result to grad_d
...@@ -861,13 +860,25 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -861,13 +860,25 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
from grad import Undefined from grad import Undefined
def wrap_producer(f): def wrap_producer(f):
class producer(omega_op): class producer(gof.lib.NewPythonOp):
def __init__(self, shape, dtype, order):
assert order == 'C' #TODO: let Numpy2 support this
if current_mode() == 'build_eval':
gof.lib.NewPythonOp.__init__(self,
[input(shape), input(dtype), input(order)],
[Numpy2(data = f(shape, dtype))])
elif current_mode() == 'build':
gof.lib.NewPythonOp.__init__(self,
[input(shape), input(dtype), input(order)],
[Numpy2(data = (shape, dtype))])
def gen_outputs(self):
return [Numpy2() for i in xrange(self.nout)]
impl = f impl = f
def grad(*args): def grad(*args):
return [Undefined] * (len(args) - 1) return [Undefined] * (len(args) - 1)
producer.__name__ = f.__name__ producer.__name__ = f.__name__
def ret(dim, dtype = 'float', order = 'C'): def ret(shape, dtype = 'float64', order = 'C'):
return producer(dim, dtype, order) return producer(shape, dtype, order).out
return ret return ret
ndarray = wrap_producer(numpy.ndarray) ndarray = wrap_producer(numpy.ndarray)
...@@ -880,19 +891,19 @@ class _testCase_producer_build_mode(unittest.TestCase): ...@@ -880,19 +891,19 @@ class _testCase_producer_build_mode(unittest.TestCase):
"""producer in build mode""" """producer in build mode"""
build_mode() build_mode()
a = ones(4) a = ones(4)
self.failUnless(a.data is None) self.failUnless(a.data is None, a.data)
self.failUnless(a.state is gof.result.Empty) self.failUnless(a.state is gof.result.Empty, a.state)
self.failUnless(a.shape == (4,)) self.failUnless(a.shape == 4, a.shape)
self.failUnless(a.dtype == 'float64') self.failUnless(str(a.dtype) == 'float64', a.dtype)
pop_mode() pop_mode()
def test_1(self): def test_1(self):
"""producer in build_eval mode""" """producer in build_eval mode"""
build_eval_mode() build_eval_mode()
a = ones(4) a = ones(4)
self.failUnless((a.data == numpy.ones(4)).all()) self.failUnless((a.data == numpy.ones(4)).all(), a.data)
self.failUnless(a.state is gof.result.Computed) self.failUnless(a.state is gof.result.Computed, a.state)
self.failUnless(a.shape == (4,)) self.failUnless(a.shape == (4,), a.shape)
self.failUnless(a.dtype == 'float64') self.failUnless(str(a.dtype) == 'float64', a.dtype)
pop_mode() pop_mode()
......
...@@ -345,7 +345,16 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings) ...@@ -345,7 +345,16 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
return ords return ords
class PythonOp(Op): class NewPythonOp(Op):
def view_map(self):
return {}
def destroy_map(self):
return {}
class PythonOp(NewPythonOp):
__metaclass__ = ClsInit __metaclass__ = ClsInit
...@@ -377,7 +386,7 @@ class PythonOp(Op): ...@@ -377,7 +386,7 @@ class PythonOp(Op):
return op.outputs return op.outputs
def __init__(self, *inputs): def __init__(self, *inputs):
Op.__init__(self, inputs, self.gen_outputs()) NewPythonOp.__init__(self, inputs, self.gen_outputs())
def __validate__(self): def __validate__(self):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
...@@ -385,10 +394,6 @@ class PythonOp(Op): ...@@ -385,10 +394,6 @@ class PythonOp(Op):
def gen_outputs(self): def gen_outputs(self):
raise AbstractFunctionError() raise AbstractFunctionError()
def view_map(self): return {}
def destroy_map(self): return {}
@staticmethod @staticmethod
def root_inputs(input): def root_inputs(input):
owner = input.owner owner = input.owner
...@@ -445,21 +450,20 @@ class PythonOp(Op): ...@@ -445,21 +450,20 @@ class PythonOp(Op):
else: else:
results = self._impl() results = self._impl()
if self.nout == 1: if self.nout == 1:
self.out.set_value(results) self.out.data = results
else: else:
assert self.nout == len(results) assert self.nout == len(results)
for result, output in zip(results, self.outputs): for result, output in zip(results, self.outputs):
output.set_value(result) output.data = result
def _perform(self): def _perform(self):
results = self._impl() results = self._impl()
if self.nout == 1: if self.nout == 1:
self.out.set_value(results) self.out.data = results
# self.outputs[0].data = results
else: else:
assert self.nout == len(results) assert self.nout == len(results)
for result, output in zip(results, self.outputs): for result, output in zip(results, self.outputs):
output.set_value(result) output.data = result
def _perform_inplace(self): def _perform_inplace(self):
results = self._impl() results = self._impl()
......
...@@ -37,8 +37,9 @@ class BrokenLinkError(GofError): ...@@ -37,8 +37,9 @@ class BrokenLinkError(GofError):
# ResultBase state keywords # ResultBase state keywords
class Empty : pass class Empty : """Memory has not been allocated"""
class Computed : pass class Allocated: """Memory has been allocated, contents are not the owner's output."""
class Computed : """Memory has been allocated, contents are the owner's output."""
############################ ############################
...@@ -63,11 +64,16 @@ class ResultBase(object): ...@@ -63,11 +64,16 @@ class ResultBase(object):
role - (rw) role - (rw)
owner - (ro) owner - (ro)
index - (ro) index - (ro)
data - (rw) data - (rw) : calls data_filter when setting
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
Methods:
alloc() - create storage in data, suitable for use by C ops.
(calls data_alloc)
Abstract Methods: Abstract Methods:
data_filter data_filter
data_alloc
Notes (from previous implementation): Notes (from previous implementation):
...@@ -184,6 +190,24 @@ class ResultBase(object): ...@@ -184,6 +190,24 @@ class ResultBase(object):
""" """
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
#
# alloc
#
def alloc(self):
"""Create self.data from data_alloc, and set state to Allocated"""
self.data = self.data_alloc() #might raise exception
self.state = Allocated
def data_alloc(self):
"""(abstract) Return an appropriate _data based on self.
If a subclass overrides this function, then that overriding
implementation will be used in alloc() to produce a data object.
"""
raise ResultBase.AbstractFunction()
# #
# replaced # replaced
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论