minor fixes, working on producer in build mode... getting state right

上级 32249295
...@@ -84,6 +84,8 @@ def _compile_dir(): ...@@ -84,6 +84,8 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class Allocated:
"""Memory has been allocated, but contents are not the owner's output."""
class Numpy2(ResultBase): class Numpy2(ResultBase):
"""Result storing a numpy ndarray""" """Result storing a numpy ndarray"""
__slots__ = ['_dtype', '_shape', ] __slots__ = ['_dtype', '_shape', ]
...@@ -121,6 +123,7 @@ class Numpy2(ResultBase): ...@@ -121,6 +123,7 @@ class Numpy2(ResultBase):
def data_alloc(self): def data_alloc(self):
self.data = numpy.ndarray(self.shape, self.dtype) self.data = numpy.ndarray(self.shape, self.dtype)
self.state = Allocated
# self._dtype is used when self._data hasn't been set yet # self._dtype is used when self._data hasn't been set yet
def __dtype_get(self): def __dtype_get(self):
...@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase): ...@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase):
def cgetspecs(names, vals, converters): def cgen(name, behavior, names, vals, converters = None):
def cgetspecs(names, vals, converters):
d = {} d = {}
for name, value in zip(names, vals): for name, value in zip(names, vals):
d[name] = value.data d[name] = value.data
specs = weave.ext_tools.assign_variable_types(names, d, type_converters = converters) #, auto_downcast = 0) specs = weave.ext_tools.assign_variable_types(names, d, type_converters = converters) #, auto_downcast = 0)
return d, specs return d, specs
def cgen(name, behavior, names, vals, converters = None):
if not converters: if not converters:
converters = type_spec.default converters = type_spec.default
for converter in converters: for converter in converters:
...@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array) ...@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array)
zeros = wrap_producer(numpy.zeros) zeros = wrap_producer(numpy.zeros)
ones = wrap_producer(numpy.ones) ones = wrap_producer(numpy.ones)
class _testCase_producer_build_mode(unittest.TestCase):
def test_0(self):
"""producer in build mode"""
build_mode()
a = ones(4)
self.failUnless(a.data is None)
self.failUnless(a.state is gof.result.Empty)
self.failUnless(a.shape == (4,))
self.failUnless(a.dtype == 'float64')
pop_mode()
def test_1(self):
"""producer in build_eval mode"""
build_eval_mode()
a = ones(4)
self.failUnless((a.data == numpy.ones(4)).all())
self.failUnless(a.state is gof.result.Computed)
self.failUnless(a.shape == (4,))
self.failUnless(a.dtype == 'float64')
pop_mode()
# Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting) # Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)
def assert_same_shapes(impl): def assert_same_shapes(impl):
...@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version() ...@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version()
add_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__iadd__)) add_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__iadd__))
class add_scalar(tensor_scalar_op): class add_scalar(tensor_scalar_op):
impl = tensor_scalar_impl(numpy.ndarray.__add__) impl = tensor_scalar_impl(numpy.ndarray.__add__)
def grad(x, a, gz): def grad(x, a, gz):
...@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op): ...@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op):
add_scalar_inplace = add_scalar.inplace_version() add_scalar_inplace = add_scalar.inplace_version()
add_scalar_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__)) add_scalar_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__))
class _testCase_add_build_mode(unittest.TestCase):
def setUp(self):
build_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
class twice(elemwise): class twice(elemwise):
def impl(x): def impl(x):
return 2.0 * x return 2.0 * x
...@@ -1100,10 +1133,7 @@ class dot(omega_op): ...@@ -1100,10 +1133,7 @@ class dot(omega_op):
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
if 0: class _testCase_dot(unittest.TestCase):
print 'SKIPPING DOT TESTS'
else:
class _testCase_dot(unittest.TestCase):
def setUp(self): def setUp(self):
build_eval_mode() build_eval_mode()
numpy.random.seed(44) numpy.random.seed(44)
......
# from op import * import op, ext, lib, link, result, env, prog, features, opt, graph
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from op import * from op import *
from ext import * from ext import *
...@@ -18,5 +11,3 @@ from features import * ...@@ -18,5 +11,3 @@ from features import *
from opt import * from opt import *
import graph import graph
#import utils
...@@ -44,6 +44,8 @@ def compute_from(nodes, history): ...@@ -44,6 +44,8 @@ def compute_from(nodes, history):
if hasattr(node, 'owner'): #node is storage if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner) compute_recursive(node.owner)
else: #node is op else: #node is op
if node.destroy_map():
raise ValueError('compute_from() does not work on nodes with destroy_maps')
for input in node.inputs: for input in node.inputs:
compute_recursive(input) compute_recursive(input)
node.perform() node.perform()
...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else: else:
return True return True
class DestroyHandler(features.Listener, features.Constraint, features.Orderings): class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env): def __init__(self, env):
...@@ -383,13 +383,14 @@ class PythonOp(Op): ...@@ -383,13 +383,14 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
raise NotImplementedError() raise AbstractFunctionError()
def view_map(self): return {} def view_map(self): return {}
def destroy_map(self): return {} def destroy_map(self): return {}
def root_inputs(self, input): @staticmethod
def root_inputs(input):
owner = input.owner owner = input.owner
if owner: if owner:
view_map = owner.view_map() view_map = owner.view_map()
......
...@@ -11,7 +11,7 @@ from err import GofError ...@@ -11,7 +11,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ]
class BrokenLink: class BrokenLink:
...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError): ...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
pass pass
# ResultBase state keywords
class Empty : pass
class Computed : pass
############################ ############################
# Result # Result
...@@ -53,6 +57,7 @@ class ResultBase(object): ...@@ -53,6 +57,7 @@ class ResultBase(object):
_role - None or (owner, index) or BrokenLink _role - None or (owner, index) or BrokenLink
_data - anything _data - anything
constant - Boolean constant - Boolean
state - one of (Empty, Allocated, Computed)
Properties: Properties:
role - (rw) role - (rw)
...@@ -60,13 +65,12 @@ class ResultBase(object): ...@@ -60,13 +65,12 @@ class ResultBase(object):
index - (ro) index - (ro)
data - (rw) data - (rw)
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods: Abstract Methods:
data_filter data_filter
Notes: Notes (from previous implementation):
A Result instance should be immutable: indeed, if some aspect of a A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become Result is changed, operations that use it might suddenly become
...@@ -89,22 +93,28 @@ class ResultBase(object): ...@@ -89,22 +93,28 @@ class ResultBase(object):
class AbstractFunction(Exception): class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called""" """Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant'] __slots__ = ['_role', 'constant', '_data', 'state']
def __init__(self, role=None, data=None, constant=False): def __init__(self, role=None, data=None, constant=False):
self._role = role self._role = role
self.constant = constant self.constant = constant
if data is None: #None is not filtered if data is None: #None is not filtered
self._data = None self._data = None
self.state = Empty
else: else:
try: try:
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction:
self._data = data self._data = data
self.state = Computed
#
# role
#
#role is pair: (owner, outputs_position)
def __get_role(self): def __get_role(self):
return self._role return self._role
def __set_role(self, role): def __set_role(self, role):
owner, index = role owner, index = role
if self._role is not None: if self._role is not None:
...@@ -116,29 +126,41 @@ class ResultBase(object): ...@@ -116,29 +126,41 @@ class ResultBase(object):
raise ValueError("Result %s was already mapped to a different index." % self) raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index return # because _owner is owner and _index == index
self._role = role self._role = role
role = property(__get_role, __set_role) role = property(__get_role, __set_role)
#owner is role[0] #
# owner
#
def __get_owner(self): def __get_owner(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0] return self._role[0]
owner = property(__get_owner, owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None") doc = "Op of which this Result is an output, or None if role is None")
#index is role[1] #
# index
#
def __get_index(self): def __get_index(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1] return self._role[1]
index = property(__get_index, index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None") doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that #
# function is defined # data
#
def __get_data(self): def __get_data(self):
return self._data return self._data
def __set_data(self, data): def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
...@@ -146,27 +168,39 @@ class ResultBase(object): ...@@ -146,27 +168,39 @@ class ResultBase(object):
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data = data
self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.""" """(abstract) Return an appropriate _data based on data.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
"""
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
#
# replaced # replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink) #
def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace): def __set_replaced(self, replace):
if replace == self.replaced: return if replace == self.replaced: return
if replace: if replace:
self._role = ResultBase.BrokenLink(self._role) self._role = ResultBase.BrokenLink(self._role)
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
################# #################
......
...@@ -77,7 +77,7 @@ class Grad(object): ...@@ -77,7 +77,7 @@ class Grad(object):
r.shape, dr.shape)) r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed: if r.state is gof.result.Computed:
self._compute_history.add(r) self._compute_history.add(r)
# add dr to self[r] # add dr to self[r]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论