minor fixes, working on producer in build mode... getting state right

上级 32249295
差异被折叠。
# from op import * import op, ext, lib, link, result, env, prog, features, opt, graph
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from op import * from op import *
from ext import * from ext import *
...@@ -18,5 +11,3 @@ from features import * ...@@ -18,5 +11,3 @@ from features import *
from opt import * from opt import *
import graph import graph
#import utils
...@@ -44,6 +44,8 @@ def compute_from(nodes, history): ...@@ -44,6 +44,8 @@ def compute_from(nodes, history):
if hasattr(node, 'owner'): #node is storage if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner) compute_recursive(node.owner)
else: #node is op else: #node is op
if node.destroy_map():
raise ValueError('compute_from() does not work on nodes with destroy_maps')
for input in node.inputs: for input in node.inputs:
compute_recursive(input) compute_recursive(input)
node.perform() node.perform()
...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else: else:
return True return True
class DestroyHandler(features.Listener, features.Constraint, features.Orderings): class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env): def __init__(self, env):
...@@ -383,13 +383,14 @@ class PythonOp(Op): ...@@ -383,13 +383,14 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
raise NotImplementedError() raise AbstractFunctionError()
def view_map(self): return {} def view_map(self): return {}
def destroy_map(self): return {} def destroy_map(self): return {}
def root_inputs(self, input): @staticmethod
def root_inputs(input):
owner = input.owner owner = input.owner
if owner: if owner:
view_map = owner.view_map() view_map = owner.view_map()
......
...@@ -11,7 +11,7 @@ from err import GofError ...@@ -11,7 +11,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ]
class BrokenLink: class BrokenLink:
...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError): ...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
pass pass
# ResultBase state keywords
class Empty : pass
class Computed : pass
############################ ############################
# Result # Result
...@@ -53,6 +57,7 @@ class ResultBase(object): ...@@ -53,6 +57,7 @@ class ResultBase(object):
_role - None or (owner, index) or BrokenLink _role - None or (owner, index) or BrokenLink
_data - anything _data - anything
constant - Boolean constant - Boolean
state - one of (Empty, Allocated, Computed)
Properties: Properties:
role - (rw) role - (rw)
...@@ -60,13 +65,12 @@ class ResultBase(object): ...@@ -60,13 +65,12 @@ class ResultBase(object):
index - (ro) index - (ro)
data - (rw) data - (rw)
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods: Abstract Methods:
data_filter data_filter
Notes: Notes (from previous implementation):
A Result instance should be immutable: indeed, if some aspect of a A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become Result is changed, operations that use it might suddenly become
...@@ -89,22 +93,28 @@ class ResultBase(object): ...@@ -89,22 +93,28 @@ class ResultBase(object):
class AbstractFunction(Exception): class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called""" """Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant'] __slots__ = ['_role', 'constant', '_data', 'state']
def __init__(self, role=None, data=None, constant=False): def __init__(self, role=None, data=None, constant=False):
self._role = role self._role = role
self.constant = constant self.constant = constant
if data is None: #None is not filtered if data is None: #None is not filtered
self._data = None self._data = None
self.state = Empty
else: else:
try: try:
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction:
self._data = data self._data = data
self.state = Computed
#
# role
#
#role is pair: (owner, outputs_position)
def __get_role(self): def __get_role(self):
return self._role return self._role
def __set_role(self, role): def __set_role(self, role):
owner, index = role owner, index = role
if self._role is not None: if self._role is not None:
...@@ -116,29 +126,41 @@ class ResultBase(object): ...@@ -116,29 +126,41 @@ class ResultBase(object):
raise ValueError("Result %s was already mapped to a different index." % self) raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index return # because _owner is owner and _index == index
self._role = role self._role = role
role = property(__get_role, __set_role) role = property(__get_role, __set_role)
#owner is role[0] #
# owner
#
def __get_owner(self): def __get_owner(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0] return self._role[0]
owner = property(__get_owner, owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None") doc = "Op of which this Result is an output, or None if role is None")
#index is role[1] #
# index
#
def __get_index(self): def __get_index(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1] return self._role[1]
index = property(__get_index, index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None") doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that #
# function is defined # data
#
def __get_data(self): def __get_data(self):
return self._data return self._data
def __set_data(self, data): def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
...@@ -146,27 +168,39 @@ class ResultBase(object): ...@@ -146,27 +168,39 @@ class ResultBase(object):
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data = data
self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.""" """(abstract) Return an appropriate _data based on data.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
"""
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
#
# replaced # replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink) #
def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace): def __set_replaced(self, replace):
if replace == self.replaced: return if replace == self.replaced: return
if replace: if replace:
self._role = ResultBase.BrokenLink(self._role) self._role = ResultBase.BrokenLink(self._role)
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
################# #################
......
...@@ -77,7 +77,7 @@ class Grad(object): ...@@ -77,7 +77,7 @@ class Grad(object):
r.shape, dr.shape)) r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed: if r.state is gof.result.Computed:
self._compute_history.add(r) self._compute_history.add(r)
# add dr to self[r] # add dr to self[r]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论