提交 ad8c564c authored 作者: Olivier Breuleux's avatar Olivier Breuleux

wrapping ResultBase._data, more tests

上级 6ce2678b
...@@ -1508,6 +1508,8 @@ sqr_inplace = sqr.inplace_version() ...@@ -1508,6 +1508,8 @@ sqr_inplace = sqr.inplace_version()
sqr_inplace.set_impl(lambda x: x.__imul__(x)) sqr_inplace.set_impl(lambda x: x.__imul__(x))
class sqrt(elemwise): class sqrt(elemwise):
impl = numpy.sqrt impl = numpy.sqrt
def grad(x, gz): def grad(x, gz):
......
...@@ -104,14 +104,15 @@ class ResultBase(object): ...@@ -104,14 +104,15 @@ class ResultBase(object):
def __init__(self, role=None, data=None, constant=False): def __init__(self, role=None, data=None, constant=False):
self._role = role self._role = role
self.constant = constant self.constant = constant
self._data = [None]
if data is None: #None is not filtered if data is None: #None is not filtered
self._data = None self._data[0] = None
self.state = Empty self.state = Empty
else: else:
try: try:
self._data = self.data_filter(data) self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction:
self._data = data self._data[0] = data
self.state = Computed self.state = Computed
# #
...@@ -165,15 +166,19 @@ class ResultBase(object): ...@@ -165,15 +166,19 @@ class ResultBase(object):
# #
def __get_data(self): def __get_data(self):
return self._data return self._data[0]
def __set_data(self, data): def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
if data is None:
self._data[0] = None
self.state = Empty
return
try: try:
self._data = self.data_filter(data) self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data[0] = data
self.state = Computed self.state = Computed
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
...@@ -240,8 +245,23 @@ class ResultBase(object): ...@@ -240,8 +245,23 @@ class ResultBase(object):
class _test_ResultBase(unittest.TestCase): class _test_ResultBase(unittest.TestCase):
def test_0(self): def test_0(self):
r = ResultBase() r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论