提交 ad8c564c authored 作者: Olivier Breuleux's avatar Olivier Breuleux

wrapping ResultBase._data, more tests

上级 6ce2678b
......@@ -1508,6 +1508,8 @@ sqr_inplace = sqr.inplace_version()
sqr_inplace.set_impl(lambda x: x.__imul__(x))
class sqrt(elemwise):
impl = numpy.sqrt
def grad(x, gz):
......
......@@ -104,14 +104,15 @@ class ResultBase(object):
def __init__(self, role=None, data=None, constant=False):
self._role = role
self.constant = constant
self._data = [None]
if data is None: #None is not filtered
self._data = None
self._data[0] = None
self.state = Empty
else:
try:
self._data = self.data_filter(data)
self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction:
self._data = data
self._data[0] = data
self.state = Computed
#
......@@ -165,15 +166,19 @@ class ResultBase(object):
#
def __get_data(self):
return self._data
return self._data[0]
def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase')
if data is None:
self._data[0] = None
self.state = Empty
return
try:
self._data = self.data_filter(data)
self._data[0] = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour
self._data = data
self._data[0] = data
self.state = Computed
data = property(__get_data, __set_data,
......@@ -240,8 +245,23 @@ class ResultBase(object):
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
def test_1(self):
r = ResultBase()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论