提交 e68b5055 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix another test in float32.

上级 0a18d7da
...@@ -765,17 +765,26 @@ class TestGemv(TestCase): ...@@ -765,17 +765,26 @@ class TestGemv(TestCase):
def test_gemv_dimensions(self): def test_gemv_dimensions(self):
A = T.matrix('A') A = T.matrix('A')
x, y = T.vectors('x', 'y') x, y = T.vectors('x', 'y')
alpha = theano.shared(1.0, name='alpha') alpha = theano.shared(theano._asarray(1.0, dtype=config.floatX),
beta = theano.shared(1.0, name='beta') name='alpha')
beta = theano.shared(theano._asarray(1.0, dtype=config.floatX),
name='beta')
z = beta * y + alpha * T.dot(A, x) z = beta * y + alpha * T.dot(A, x)
f = theano.function([A, x, y], z) f = theano.function([A, x, y], z)
# Matrix value
A_val = numpy.ones((5,3), dtype=config.floatX) A_val = numpy.ones((5,3), dtype=config.floatX)
f(A_val, numpy.ones(3), numpy.ones(5)) # Different vector length
self.assertRaises(ValueError, f, A_val, numpy.ones(4), numpy.ones(5)) ones_3 = numpy.ones(3, dtype=config.floatX)
self.assertRaises(ValueError, f, A_val, numpy.ones(3), numpy.ones(6)) ones_4 = numpy.ones(4, dtype=config.floatX)
self.assertRaises(ValueError, f, A_val, numpy.ones(4), numpy.ones(6)) ones_5 = numpy.ones(5, dtype=config.floatX)
ones_6 = numpy.ones(6, dtype=config.floatX)
f(A_val, ones_3, ones_5)
self.assertRaises(ValueError, f, A_val, ones_4, ones_5)
self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
# The following gemv tests were added in March 2011 by Ian Goodfellow # The following gemv tests were added in March 2011 by Ian Goodfellow
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论