提交 5d01f64b authored 作者: James Bergstra's avatar James Bergstra

merged

import traceback
from tensor import *
import tensor # for hidden symbols
......@@ -1428,6 +1429,7 @@ class T_tensor(unittest.TestCase):
t.data = numpy.ones((2,7,1))
self.fail()
except ValueError, e:
#traceback.print_exc()
self.failUnless(e[0] is Tensor.filter.E_rank)
try:
t.data = numpy.ones(1)
......@@ -1452,6 +1454,7 @@ class T_tensor(unittest.TestCase):
t.data = numpy.ones((1,2))
self.fail()
except ValueError, e:
#traceback.print_exc()
self.failUnless(e[0] is Tensor.filter.E_shape)
try:
t.data = numpy.ones((0,1))
......
......@@ -67,19 +67,19 @@ class Tensor(Result):
#
# filter
#
def filter(self, arr):
def filter(self, arg):
"""Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape."""
if not (isinstance(arr, numpy.ndarray) \
and arr.dtype==self.dtype):
arr = numpy.asarray(arr, dtype = self.dtype)
if (isinstance(arg, numpy.ndarray) and arg.dtype==self.dtype):
arr = arg
else:
arr = numpy.asarray(arg, dtype = self.dtype)
if len(self.broadcastable) != len(arr.shape):
raise ValueError(Tensor.filter.E_rank,
self.broadcastable,
arr.shape,
self.owner)
for b, s in zip(self.broadcastable, arr.shape):
"Can't assign ndarray of rank %i to Tensor of rank %i, when filtering %s" % (len(arr.shape), len(self.broadcastable), arg))
for i, (b, s) in enumerate(zip(self.broadcastable, arr.shape)):
if b and (s != 1):
raise ValueError(Tensor.filter.E_shape)
raise ValueError(Tensor.filter.E_shape,
('(dimension %i when filtering %s)'%(i, arg)))
return arr
# these strings are here so that tests can use them
filter.E_rank = 'wrong rank'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论