minor fixes, working on producer in build mode... getting state right

上级 32249295
差异被折叠。
# from op import *
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
import op, ext, lib, link, result, env, prog, features, opt, graph
from op import *
from ext import *
......@@ -18,5 +11,3 @@ from features import *
from opt import *
import graph
#import utils
......@@ -44,6 +44,8 @@ def compute_from(nodes, history):
if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner)
else: #node is op
if node.destroy_map():
raise ValueError('compute_from() does not work on nodes with destroy_maps')
for input in node.inputs:
compute_recursive(input)
node.perform()
......@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else:
return True
class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env):
......@@ -383,13 +383,14 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs])
def gen_outputs(self):
raise NotImplementedError()
raise AbstractFunctionError()
def view_map(self): return {}
def destroy_map(self): return {}
def root_inputs(self, input):
@staticmethod
def root_inputs(input):
owner = input.owner
if owner:
view_map = owner.view_map()
......
......@@ -11,7 +11,7 @@ from err import GofError
from utils import AbstractFunctionError
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError']
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ]
class BrokenLink:
......@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
pass
# ResultBase state keywords
class Empty : pass
class Computed : pass
############################
# Result
......@@ -53,6 +57,7 @@ class ResultBase(object):
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
state - one of (Empty, Allocated, Computed)
Properties:
role - (rw)
......@@ -60,13 +65,12 @@ class ResultBase(object):
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods:
data_filter
Notes:
Notes (from previous implementation):
A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become
......@@ -89,22 +93,28 @@ class ResultBase(object):
class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant']
__slots__ = ['_role', 'constant', '_data', 'state']
def __init__(self, role=None, data=None, constant=False):
self._role = role
self.constant = constant
if data is None: #None is not filtered
self._data = None
self.state = Empty
else:
try:
self._data = self.data_filter(data)
except ResultBase.AbstractFunction:
self._data = data
self.state = Computed
#
# role
#
#role is pair: (owner, outputs_position)
def __get_role(self):
return self._role
def __set_role(self, role):
owner, index = role
if self._role is not None:
......@@ -116,29 +126,41 @@ class ResultBase(object):
raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index
self._role = role
role = property(__get_role, __set_role)
#owner is role[0]
#
# owner
#
def __get_owner(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0]
owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None")
#index is role[1]
#
# index
#
def __get_index(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1]
index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
#
# data
#
def __get_data(self):
return self._data
def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase')
......@@ -146,27 +168,39 @@ class ResultBase(object):
self._data = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour
self._data = data
self.state = Computed
data = property(__get_data, __set_data,
doc = "The storage associated with this result")
def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data."""
"""(abstract) Return an appropriate _data based on data.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
"""
raise ResultBase.AbstractFunction()
#
# replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink)
#
def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace):
if replace == self.replaced: return
if replace:
self._role = ResultBase.BrokenLink(self._role)
else:
self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
#################
......
......@@ -77,7 +77,7 @@ class Grad(object):
r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed:
if r.state is gof.result.Computed:
self._compute_history.add(r)
# add dr to self[r]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论