提交 dcfe260e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add dimension check to Gemv.perform. Test some cases that were incorrect.

These tests were previously computing wrong values.
上级 c02d3283
......@@ -87,6 +87,11 @@ class Gemv(Op):
if _have_fblas:
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))#
#Here I suppose that A is in c order. If we don't make it explicitly
# as fortran order, scipy 0.7.2 seam to create a copy in fortran
# order instead of just reshaping it and using the trans flag.
......
......@@ -761,6 +761,23 @@ def test_gemv2():
if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True
class TestGemv(TestCase):
def test_gemv_dimensions(self):
A = T.matrix('A')
x, y = T.vectors('x', 'y')
alpha = theano.shared(1.0, name='alpha')
beta = theano.shared(1.0, name='beta')
z = beta * y + alpha * T.dot(A, x)
f = theano.function([A, x, y], z)
A_val = numpy.ones((5,3), dtype=config.floatX)
f(A_val, numpy.ones(3), numpy.ones(5))
self.assertRaises(ValueError, f, A_val, numpy.ones(4), numpy.ones(5))
self.assertRaises(ValueError, f, A_val, numpy.ones(3), numpy.ones(6))
self.assertRaises(ValueError, f, A_val, numpy.ones(4), numpy.ones(6))
# The following gemv tests were added in March 2011 by Ian Goodfellow
# and are based on the gemv tests from scipy
# http://projects.scipy.org/scipy/browser/trunk/scipy/linalg/tests/test_fblas.py?rev=6803
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论