added tests, comments for base tensor

上级 cb772e00
from base_tensor import *
import unittest
from copy import copy
def _tensor(data, broadcastable=None, role=None, name=None):
"""Return a BaseTensor containing given data"""
data = numpy.asarray(data)
if broadcastable is None:
broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape)
rval = BaseTensor(data.dtype, broadcastable, role, name)
rval.data = data # will raise if broadcastable was mis-specified
return rval
class T_tensor(unittest.TestCase):
def test0(self): # allocate from a scalar float
t = _tensor(1.0)
self.failUnless(isinstance(t, BaseTensor))
self.failUnless(t.dtype == 'float64')
self.failUnless(t.broadcastable == ())
self.failUnless(t.role == None)
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'float64')
self.failUnless(t.data == 1.0)
def test0_int(self): # allocate from a scalar float
t = _tensor(1)
self.failUnless(isinstance(t, BaseTensor))
self.failUnless(t.dtype == 'int64' or t.dtype == 'int32')
def test1(self): # allocate from a vector of ints, not broadcastable
t = _tensor(numpy.ones(5,dtype='int32'))
self.failUnless(isinstance(t, BaseTensor))
self.failUnless(t.dtype == 'int32')
self.failUnless(t.broadcastable == (0,))
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'int32')
def test2(self): # allocate from a column matrix of complex with name
t = _tensor(numpy.ones((5,1),dtype='complex64'),name='bart')
self.failUnless(isinstance(t, BaseTensor))
self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,1))
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(t.name == 'bart')
def test2b(self): # allocate from a column matrix, not broadcastable
t = _tensor(numpy.ones((5,1),dtype='complex64'),broadcastable=0)
self.failUnless(isinstance(t, BaseTensor))
self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,0))
self.failUnless(isinstance(t.data, numpy.ndarray))
def test_data_normal(self): #test that assigning to .data works when it should
t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
o27 = numpy.ones((2,7))
t.data = o27
lst = t._data
self.failUnless(t.data.shape == (2,7))
self.failUnless(t.data is o27)
self.failUnless(t._data is lst)
def test_data_badrank0(self):
t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
try:
t.data = numpy.ones((2,7,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_rank)
try:
t.data = numpy.ones(1)
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_rank)
def test_data_badrank1(self):
t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try:
t.data = numpy.ones((1,1,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_rank)
try:
t.data = numpy.ones(1)
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_rank)
def test_data_badshape0(self):
t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try:
t.data = numpy.ones((1,2))
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_shape)
try:
t.data = numpy.ones((0,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_shape)
class T_stdlib(unittest.TestCase):
def test0(self):
t = _tensor(1.0)
tt = t.clone(False)
self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data is None)
self.failUnless(t.data == 1.0)
def test0b(self):
t = _tensor(1.0)
tt = t.clone()
self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data is None)
self.failUnless(t.data == 1.0)
def test1(self):
t = _tensor(1.0)
tt = t.clone(True)
self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data == 1.0)
self.failUnless(t.data == 1.0)
self.failUnless(t.data is not tt.data)
def test1b(self):
t = _tensor(1.0)
tt = copy(t)
self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data == 1.0)
self.failUnless(t.data == 1.0)
self.failUnless(t.data is not tt.data)
if __name__ == '__main__':
unittest.main()
......@@ -46,93 +46,6 @@ def verify_grad(testcase, op_cls, pt_list, n_tests=1, rng=numpy.random, eps=0.00
verify_grad.E_grad = 'gradient error exceeded tolerance'
class T_tensor(unittest.TestCase):
def test0(self): # allocate from a scalar float
t = tensor(1.0)
self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'float64')
self.failUnless(t.broadcastable == ())
self.failUnless(t.role == None)
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'float64')
self.failUnless(t.data == 1.0)
def test0_int(self): # allocate from a scalar float
t = tensor(1)
self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'int64' or t.dtype == 'int32')
def test1(self): # allocate from a vector of ints, not broadcastable
t = tensor(numpy.ones(5,dtype='int32'))
self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'int32')
self.failUnless(t.broadcastable == (0,))
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'int32')
def test2(self): # allocate from a column matrix of complex with name
t = tensor(numpy.ones((5,1),dtype='complex64'),name='bart')
self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,1))
self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(t.name == 'bart')
def test2b(self): # allocate from a column matrix, not broadcastable
t = tensor(numpy.ones((5,1),dtype='complex64'),broadcastable=0)
self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,0))
self.failUnless(isinstance(t.data, numpy.ndarray))
def test_data_normal(self): #test that assigning to .data works when it should
t = tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
o27 = numpy.ones((2,7))
t.data = o27
lst = t._data
self.failUnless(t.data.shape == (2,7))
self.failUnless(t.data is o27)
self.failUnless(t._data is lst)
def test_data_badrank0(self):
t = tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
try:
t.data = numpy.ones((2,7,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank)
try:
t.data = numpy.ones(1)
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank)
def test_data_badrank1(self):
t = tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try:
t.data = numpy.ones((1,1,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank)
try:
t.data = numpy.ones(1)
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank)
def test_data_badshape0(self):
t = tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try:
t.data = numpy.ones((1,2))
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_shape)
try:
t.data = numpy.ones((0,1))
self.fail()
except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_shape)
class T_stdlib(unittest.TestCase):
def test0(self):
t = tensor(1.0)
tt = copy(t)
self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable == tt.broadcastable)
self.failUnless(t.broadcastable is tt.broadcastable)
self.failIf(t.data is tt.data)
def check_eq(self, node_in, node_out, arg_in, arg_out):
fn = Function([node_in], [node_out])
......
"""A simple class to store ndarray data """
from gof import ResultBase
import numpy
......@@ -50,6 +51,7 @@ class BaseTensor(ResultBase):
# filter
#
def filter(self, arr):
"""cast to an ndarray and ensure arr has correct rank, shape"""
if not isinstance(arr, numpy.ndarray):
arr = numpy.asarray(arr, dtype = self.dtype)
if len(self.broadcastable) != len(arr.shape):
......@@ -159,8 +161,10 @@ class BaseTensor(ResultBase):
return self.clone(True)
def clone(self, transfer_data = False):
"""
Returns a copy of this Tensor. If there is data stored inside it, it is also copied.
"""Return a copy of this instance (with its own attributes)
If transfer_data is True, a copy of self.data is assigned to the copy's
data property, otherwise the copy's data is left as None.
"""
cpy = self.__class__(self.dtype, self.broadcastable, None, self.name)
if transfer_data:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论