提交 5d01f64b authored 作者: James Bergstra's avatar James Bergstra

merged

import traceback
from tensor import * from tensor import *
import tensor # for hidden symbols import tensor # for hidden symbols
...@@ -1428,6 +1429,7 @@ class T_tensor(unittest.TestCase): ...@@ -1428,6 +1429,7 @@ class T_tensor(unittest.TestCase):
t.data = numpy.ones((2,7,1)) t.data = numpy.ones((2,7,1))
self.fail() self.fail()
except ValueError, e: except ValueError, e:
#traceback.print_exc()
self.failUnless(e[0] is Tensor.filter.E_rank) self.failUnless(e[0] is Tensor.filter.E_rank)
try: try:
t.data = numpy.ones(1) t.data = numpy.ones(1)
...@@ -1452,6 +1454,7 @@ class T_tensor(unittest.TestCase): ...@@ -1452,6 +1454,7 @@ class T_tensor(unittest.TestCase):
t.data = numpy.ones((1,2)) t.data = numpy.ones((1,2))
self.fail() self.fail()
except ValueError, e: except ValueError, e:
#traceback.print_exc()
self.failUnless(e[0] is Tensor.filter.E_shape) self.failUnless(e[0] is Tensor.filter.E_shape)
try: try:
t.data = numpy.ones((0,1)) t.data = numpy.ones((0,1))
......
...@@ -67,19 +67,19 @@ class Tensor(Result): ...@@ -67,19 +67,19 @@ class Tensor(Result):
# #
# filter # filter
# #
def filter(self, arr): def filter(self, arg):
"""Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape.""" """Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape."""
if not (isinstance(arr, numpy.ndarray) \ if (isinstance(arg, numpy.ndarray) and arg.dtype==self.dtype):
and arr.dtype==self.dtype): arr = arg
arr = numpy.asarray(arr, dtype = self.dtype) else:
arr = numpy.asarray(arg, dtype = self.dtype)
if len(self.broadcastable) != len(arr.shape): if len(self.broadcastable) != len(arr.shape):
raise ValueError(Tensor.filter.E_rank, raise ValueError(Tensor.filter.E_rank,
self.broadcastable, "Can't assign ndarray of rank %i to Tensor of rank %i, when filtering %s" % (len(arr.shape), len(self.broadcastable), arg))
arr.shape, for i, (b, s) in enumerate(zip(self.broadcastable, arr.shape)):
self.owner)
for b, s in zip(self.broadcastable, arr.shape):
if b and (s != 1): if b and (s != 1):
raise ValueError(Tensor.filter.E_shape) raise ValueError(Tensor.filter.E_shape,
('(dimension %i when filtering %s)'%(i, arg)))
return arr return arr
# these strings are here so that tests can use them # these strings are here so that tests can use them
filter.E_rank = 'wrong rank' filter.E_rank = 'wrong rank'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论