minor fixes, working on producer in build mode... getting state right

上级 32249295
...@@ -84,6 +84,8 @@ def _compile_dir(): ...@@ -84,6 +84,8 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class Allocated:
"""Memory has been allocated, but contents are not the owner's output."""
class Numpy2(ResultBase): class Numpy2(ResultBase):
"""Result storing a numpy ndarray""" """Result storing a numpy ndarray"""
__slots__ = ['_dtype', '_shape', ] __slots__ = ['_dtype', '_shape', ]
...@@ -121,6 +123,7 @@ class Numpy2(ResultBase): ...@@ -121,6 +123,7 @@ class Numpy2(ResultBase):
def data_alloc(self): def data_alloc(self):
self.data = numpy.ndarray(self.shape, self.dtype) self.data = numpy.ndarray(self.shape, self.dtype)
self.state = Allocated
# self._dtype is used when self._data hasn't been set yet # self._dtype is used when self._data hasn't been set yet
def __dtype_get(self): def __dtype_get(self):
...@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase): ...@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase):
def cgetspecs(names, vals, converters):
d = {}
for name, value in zip(names, vals):
d[name] = value.data
specs = weave.ext_tools.assign_variable_types(names, d, type_converters = converters) #, auto_downcast = 0)
return d, specs
def cgen(name, behavior, names, vals, converters = None): def cgen(name, behavior, names, vals, converters = None):
def cgetspecs(names, vals, converters):
d = {}
for name, value in zip(names, vals):
d[name] = value.data
specs = weave.ext_tools.assign_variable_types(names, d, type_converters = converters) #, auto_downcast = 0)
return d, specs
if not converters: if not converters:
converters = type_spec.default converters = type_spec.default
for converter in converters: for converter in converters:
...@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array) ...@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array)
zeros = wrap_producer(numpy.zeros) zeros = wrap_producer(numpy.zeros)
ones = wrap_producer(numpy.ones) ones = wrap_producer(numpy.ones)
class _testCase_producer_build_mode(unittest.TestCase):
def test_0(self):
"""producer in build mode"""
build_mode()
a = ones(4)
self.failUnless(a.data is None)
self.failUnless(a.state is gof.result.Empty)
self.failUnless(a.shape == (4,))
self.failUnless(a.dtype == 'float64')
pop_mode()
def test_1(self):
"""producer in build_eval mode"""
build_eval_mode()
a = ones(4)
self.failUnless((a.data == numpy.ones(4)).all())
self.failUnless(a.state is gof.result.Computed)
self.failUnless(a.shape == (4,))
self.failUnless(a.dtype == 'float64')
pop_mode()
# Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting) # Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)
def assert_same_shapes(impl): def assert_same_shapes(impl):
...@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version() ...@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version()
add_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__iadd__)) add_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__iadd__))
class add_scalar(tensor_scalar_op): class add_scalar(tensor_scalar_op):
impl = tensor_scalar_impl(numpy.ndarray.__add__) impl = tensor_scalar_impl(numpy.ndarray.__add__)
def grad(x, a, gz): def grad(x, a, gz):
...@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op): ...@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op):
add_scalar_inplace = add_scalar.inplace_version() add_scalar_inplace = add_scalar.inplace_version()
add_scalar_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__)) add_scalar_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__))
class _testCase_add_build_mode(unittest.TestCase):
def setUp(self):
build_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
class twice(elemwise): class twice(elemwise):
def impl(x): def impl(x):
return 2.0 * x return 2.0 * x
...@@ -1100,160 +1133,157 @@ class dot(omega_op): ...@@ -1100,160 +1133,157 @@ class dot(omega_op):
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
if 0: class _testCase_dot(unittest.TestCase):
print 'SKIPPING DOT TESTS' def setUp(self):
else: build_eval_mode()
class _testCase_dot(unittest.TestCase): numpy.random.seed(44)
def setUp(self): def tearDown(self):
build_eval_mode() pop_mode()
numpy.random.seed(44)
def tearDown(self): @staticmethod
pop_mode() def rand(*args):
return numpy.random.rand(*args)
@staticmethod
def rand(*args): def cmp_dot(self,x,y):
return numpy.random.rand(*args) def spec(x):
def cmp_dot(self,x,y):
def spec(x):
x = numpy.asarray(x)
return type(x), x.dtype, x.shape
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x) x = numpy.asarray(x)
y = numpy.asarray(y) return type(x), x.dtype, x.shape
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x)
y = numpy.asarray(y)
z = dot(x,y)
p = compile.single(z)
if len(x.shape):
x[:] = numpy.random.rand(*x.shape)
else:
x.fill(numpy.random.rand(*x.shape))
if len(y.shape):
y[:] = numpy.random.rand(*y.shape)
else:
y.fill(numpy.random.rand(*y.shape))
p() # recalculate z
self.failUnless(_approx_eq(z, numpy.dot(x,y)))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2)
def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5))
def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7))
def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7))
def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 )
def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5))
def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7))
def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0)
def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6))
def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0)
def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_fail_1_1(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6)
try:
z = dot(x,y) z = dot(x,y)
p = compile.single(z) except ValueError, e:
if len(x.shape): self.failUnless(str(e) == 'objects are not aligned', e)
x[:] = numpy.random.rand(*x.shape) return
else: self.fail()
x.fill(numpy.random.rand(*x.shape))
if len(y.shape): def test_dot_fail_1_2(self):
y[:] = numpy.random.rand(*y.shape) x = numpy.random.rand(5)
else: y = numpy.random.rand(6,4)
y.fill(numpy.random.rand(*y.shape)) try:
p() # recalculate z z = dot(x,y)
self.failUnless(_approx_eq(z, numpy.dot(x,y))) except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2) return
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5)) self.fail()
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7)) def test_dot_fail_1_3(self):
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7)) x = numpy.random.rand(5)
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 ) y = numpy.random.rand(6,4,7)
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5)) try:
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7)) z = dot(x,y)
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7)) except ValueError, e:
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0) self.failUnless(str(e) == 'objects are not aligned', e)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6)) return
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7)) self.fail()
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7)) def test_dot_fail_2_1(self):
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0) x = numpy.random.rand(5,4)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6)) y = numpy.random.rand(6)
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7)) try:
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7)) z = dot(x,y)
def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2) except ValueError, e:
def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5)) self.failUnless(str(e) == 'objects are not aligned', e)
def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7)) return
def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7)) self.fail()
def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 ) def test_dot_fail_2_2(self):
def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5)) x = numpy.random.rand(5,4)
def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7)) y = numpy.random.rand(6,7)
def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7)) try:
def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0) z = dot(x,y)
def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6)) except ValueError, e:
def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7)) self.failUnless(str(e) == 'objects are not aligned', e)
def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7)) return
def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0) self.fail()
def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6)) def test_dot_fail_2_3(self):
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7)) x = numpy.random.rand(5,4)
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7)) y = numpy.random.rand(6,7,8)
try:
def test_dot_fail_1_1(self): z = dot(x,y)
x = numpy.random.rand(5) except ValueError, e:
y = numpy.random.rand(6) self.failUnless(str(e) == 'objects are not aligned', e)
try: return
z = dot(x,y) self.fail()
except ValueError, e: def test_dot_fail_3_1(self):
self.failUnless(str(e) == 'objects are not aligned', e) x = numpy.random.rand(5,4,3)
return y = numpy.random.rand(6)
self.fail() try:
z = dot(x,y)
def test_dot_fail_1_2(self): except ValueError, e:
x = numpy.random.rand(5) self.failUnless(str(e) == 'objects are not aligned', e)
y = numpy.random.rand(6,4) return
try: self.fail()
z = dot(x,y) def test_dot_fail_3_2(self):
except ValueError, e: x = numpy.random.rand(5,4,3)
self.failUnless(str(e) == 'objects are not aligned', e) y = numpy.random.rand(6,7)
return try:
self.fail() z = dot(x,y)
def test_dot_fail_1_3(self): except ValueError, e:
x = numpy.random.rand(5) self.failUnless(str(e) == 'objects are not aligned', e)
y = numpy.random.rand(6,4,7) return
try: self.fail()
z = dot(x,y) def test_dot_fail_3_3(self):
except ValueError, e: x = numpy.random.rand(5,4,3)
self.failUnless(str(e) == 'objects are not aligned', e) y = numpy.random.rand(6,7,8)
return try:
self.fail() z = dot(x,y)
def test_dot_fail_2_1(self): except ValueError, e:
x = numpy.random.rand(5,4) self.failUnless(str(e) == 'objects are not aligned', e)
y = numpy.random.rand(6) return
try: self.fail()
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_2(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_3(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_1(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_2(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_3(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
class gemm(omega_op): class gemm(omega_op):
def destroy_map(self): return {self.out:[self.inputs[0]]} def destroy_map(self): return {self.out:[self.inputs[0]]}
......
# from op import * import op, ext, lib, link, result, env, prog, features, opt, graph
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from op import * from op import *
from ext import * from ext import *
...@@ -18,5 +11,3 @@ from features import * ...@@ -18,5 +11,3 @@ from features import *
from opt import * from opt import *
import graph import graph
#import utils
...@@ -44,6 +44,8 @@ def compute_from(nodes, history): ...@@ -44,6 +44,8 @@ def compute_from(nodes, history):
if hasattr(node, 'owner'): #node is storage if hasattr(node, 'owner'): #node is storage
compute_recursive(node.owner) compute_recursive(node.owner)
else: #node is op else: #node is op
if node.destroy_map():
raise ValueError('compute_from() does not work on nodes with destroy_maps')
for input in node.inputs: for input in node.inputs:
compute_recursive(input) compute_recursive(input)
node.perform() node.perform()
...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else: else:
return True return True
class DestroyHandler(features.Listener, features.Constraint, features.Orderings): class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env): def __init__(self, env):
...@@ -383,13 +383,14 @@ class PythonOp(Op): ...@@ -383,13 +383,14 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
raise NotImplementedError() raise AbstractFunctionError()
def view_map(self): return {} def view_map(self): return {}
def destroy_map(self): return {} def destroy_map(self): return {}
def root_inputs(self, input): @staticmethod
def root_inputs(input):
owner = input.owner owner = input.owner
if owner: if owner:
view_map = owner.view_map() view_map = owner.view_map()
......
...@@ -11,7 +11,7 @@ from err import GofError ...@@ -11,7 +11,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError' ]
class BrokenLink: class BrokenLink:
...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError): ...@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
pass pass
# ResultBase state keywords
class Empty : pass
class Computed : pass
############################ ############################
# Result # Result
...@@ -53,6 +57,7 @@ class ResultBase(object): ...@@ -53,6 +57,7 @@ class ResultBase(object):
_role - None or (owner, index) or BrokenLink _role - None or (owner, index) or BrokenLink
_data - anything _data - anything
constant - Boolean constant - Boolean
state - one of (Empty, Allocated, Computed)
Properties: Properties:
role - (rw) role - (rw)
...@@ -60,13 +65,12 @@ class ResultBase(object): ...@@ -60,13 +65,12 @@ class ResultBase(object):
index - (ro) index - (ro)
data - (rw) data - (rw)
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods: Abstract Methods:
data_filter data_filter
Notes: Notes (from previous implementation):
A Result instance should be immutable: indeed, if some aspect of a A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become Result is changed, operations that use it might suddenly become
...@@ -89,22 +93,28 @@ class ResultBase(object): ...@@ -89,22 +93,28 @@ class ResultBase(object):
class AbstractFunction(Exception): class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called""" """Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant'] __slots__ = ['_role', 'constant', '_data', 'state']
def __init__(self, role=None, data=None, constant=False): def __init__(self, role=None, data=None, constant=False):
self._role = role self._role = role
self.constant = constant self.constant = constant
if data is None: #None is not filtered if data is None: #None is not filtered
self._data = None self._data = None
self.state = Empty
else: else:
try: try:
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction:
self._data = data self._data = data
self.state = Computed
#
# role
#
#role is pair: (owner, outputs_position)
def __get_role(self): def __get_role(self):
return self._role return self._role
def __set_role(self, role): def __set_role(self, role):
owner, index = role owner, index = role
if self._role is not None: if self._role is not None:
...@@ -116,29 +126,41 @@ class ResultBase(object): ...@@ -116,29 +126,41 @@ class ResultBase(object):
raise ValueError("Result %s was already mapped to a different index." % self) raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index return # because _owner is owner and _index == index
self._role = role self._role = role
role = property(__get_role, __set_role) role = property(__get_role, __set_role)
#owner is role[0] #
# owner
#
def __get_owner(self): def __get_owner(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0] return self._role[0]
owner = property(__get_owner, owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None") doc = "Op of which this Result is an output, or None if role is None")
#index is role[1] #
# index
#
def __get_index(self): def __get_index(self):
if self._role is None: return None if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1] return self._role[1]
index = property(__get_index, index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None") doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that #
# function is defined # data
#
def __get_data(self): def __get_data(self):
return self._data return self._data
def __set_data(self, data): def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
...@@ -146,27 +168,39 @@ class ResultBase(object): ...@@ -146,27 +168,39 @@ class ResultBase(object):
self._data = self.data_filter(data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data = data
self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.""" """(abstract) Return an appropriate _data based on data.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
"""
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
#
# replaced # replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink) #
def __get_replaced(self):
return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace): def __set_replaced(self, replace):
if replace == self.replaced: return if replace == self.replaced: return
if replace: if replace:
self._role = ResultBase.BrokenLink(self._role) self._role = ResultBase.BrokenLink(self._role)
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
################# #################
......
...@@ -77,7 +77,7 @@ class Grad(object): ...@@ -77,7 +77,7 @@ class Grad(object):
r.shape, dr.shape)) r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed: if r.state is gof.result.Computed:
self._compute_history.add(r) self._compute_history.add(r)
# add dr to self[r] # add dr to self[r]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论