fixed bug in BaseTensor.filter

上级 05d37e14
...@@ -99,6 +99,11 @@ class T_tensor(unittest.TestCase): ...@@ -99,6 +99,11 @@ class T_tensor(unittest.TestCase):
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is BaseTensor.filter.E_shape) self.failUnless(e[0] is BaseTensor.filter.E_shape)
def test_cast0(self):
t = BaseTensor('float32', [0])
t.data = numpy.random.rand(4) > 0.5
print t.data
class T_stdlib(unittest.TestCase): class T_stdlib(unittest.TestCase):
def test0(self): def test0(self):
t = _tensor(1.0) t = _tensor(1.0)
......
...@@ -57,7 +57,8 @@ class BaseTensor(ResultBase): ...@@ -57,7 +57,8 @@ class BaseTensor(ResultBase):
# #
def filter(self, arr): def filter(self, arr):
"""cast to an ndarray and ensure arr has correct rank, shape""" """cast to an ndarray and ensure arr has correct rank, shape"""
if not isinstance(arr, numpy.ndarray): if not (isinstance(arr, numpy.ndarray) \
and arr.dtype==self.dtype):
arr = numpy.asarray(arr, dtype = self.dtype) arr = numpy.asarray(arr, dtype = self.dtype)
if len(self.broadcastable) != len(arr.shape): if len(self.broadcastable) != len(arr.shape):
raise ValueError(BaseTensor.filter.E_rank, raise ValueError(BaseTensor.filter.E_rank,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论