提交 ed509d18 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test for gemv and dot with dimensions of zeros.

上级 50605d2c
...@@ -189,7 +189,7 @@ class Gemv(Op): ...@@ -189,7 +189,7 @@ class Gemv(Op):
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
if _have_fblas: if _have_fblas and y.shape[0]!=0 and x.shape[0]!=0:
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
......
...@@ -3173,11 +3173,27 @@ class t_dot(unittest.TestCase): ...@@ -3173,11 +3173,27 @@ class t_dot(unittest.TestCase):
#def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7)) #def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
#def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 ) #def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5)) def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d0_1d0(self): self.cmp_dot(self.rand(0), self.rand(0))
#numpy return matrix not aligned...
#def test_dot_1d_1d0(self): self.cmp_dot(self.rand(5), self.rand(0))
#numpy return matrix not aligned...
#def test_dot_1d0_1d(self): self.cmp_dot(self.rand(0), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7)) def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d0_2d(self): self.cmp_dot(self.rand(0), self.rand(0,7))
def test_dot_1d_2d0(self): self.cmp_dot(self.rand(6), self.rand(6,0))
def test_dot_1d0_2d0(self): self.cmp_dot(self.rand(0), self.rand(0,0))
#def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7)) #def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
#def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0) #def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6)) def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d0_1d(self): self.cmp_dot(self.rand(0,6), self.rand(6))
def test_dot_2d_1d0(self): self.cmp_dot(self.rand(5,0), self.rand(0))
def test_dot_2d0_1d0(self): self.cmp_dot(self.rand(0,0), self.rand(0))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7)) def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d0_2d(self): self.cmp_dot(self.rand(0,6), self.rand(6,7))
def test_dot_2d_2d0(self): self.cmp_dot(self.rand(5,6), self.rand(6,0))
def test_dot_2d0_2d0(self): self.cmp_dot(self.rand(0,6), self.rand(6,0))
def test_dot_2d_0_2d(self): self.cmp_dot(self.rand(5,0), self.rand(0,7))
def test_dot_2d0_0_2d0(self): self.cmp_dot(self.rand(0,6), self.rand(6,0))
#def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7)) #def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
#def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0) #def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
#def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6)) #def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
......
...@@ -821,13 +821,14 @@ class TestGemv(TestCase): ...@@ -821,13 +821,14 @@ class TestGemv(TestCase):
assert sum([isinstance(node.op, T.blas.Dot22) for node in assert sum([isinstance(node.op, T.blas.Dot22) for node in
f.maker.env.toposort() ]) == 1 f.maker.env.toposort() ]) == 1
def test_gemv1(self): @staticmethod
def t_gemv1(m_shp):
''' test vector2+dot(matrix,vector1) ''' ''' test vector2+dot(matrix,vector1) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(3,)), dtype='float32') v2_orig = numpy.array(rng.uniform(size=(m_shp[0],)), dtype='float32')
v2 = theano.shared(v2_orig) v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=(3,2)), dtype='float32')) m = theano.shared(numpy.array(rng.uniform(size=m_shp), dtype='float32'))
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt) f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt)
...@@ -853,6 +854,12 @@ class TestGemv(TestCase): ...@@ -853,6 +854,12 @@ class TestGemv(TestCase):
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert topo[0].op.inplace==True assert topo[0].op.inplace==True
def test_gemv1(self):
self.t_gemv1((3,2))
self.t_gemv1((0,2))
self.t_gemv1((3,0))
self.t_gemv1((0,0))
def test_gemv2(self): def test_gemv2(self):
''' test vector2+dot(vector1,matrix) ''' ''' test vector2+dot(vector1,matrix) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论