提交 29b0fa88 authored 作者: James Bergstra's avatar James Bergstra 提交者: Frederic

fixing test_blas post cblas ops and opts

上级 1e212181
...@@ -872,22 +872,21 @@ def test_dot_w_self(): ...@@ -872,22 +872,21 @@ def test_dot_w_self():
## Tests for Gemv ## Tests for Gemv
############################################################################### ###############################################################################
class TestGemv(TestCase): class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def test_dot_vm(self): def test_dot_vm(self):
''' Test vector dot matrix ''' ''' Test vector dot matrix '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32')) m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32'))
f = theano.function([], theano.dot(v,m), mode = mode_blas_opt) f = theano.function([], theano.dot(v,m), mode=mode_blas_opt)
# Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot)
self.assertFunctionContains1(f, Gemv(True))
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value())) assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 0
assert sum([isinstance(node.op, T.blas.Dot22) for node in
f.maker.env.toposort() ]) == 1
def test_dot_mv(self): def test_dot_mv(self):
''' Test matrix dot vector ''' ''' Test matrix dot vector '''
...@@ -895,17 +894,15 @@ class TestGemv(TestCase): ...@@ -895,17 +894,15 @@ class TestGemv(TestCase):
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(3,2)), m = theano.shared(numpy.array(rng.uniform(size=(3,2)),
dtype='float32')) dtype='float32'))
f = theano.function([], theano.dot(m,v), mode = mode_blas_opt) f = theano.function([], theano.dot(m,v), mode=mode_blas_opt)
# Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot)
self.assertFunctionContains1(f, Gemv(True))
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(m.get_value(), v.get_value())) assert numpy.allclose(f(), numpy.dot(m.get_value(), v.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 0
assert sum([isinstance(node.op, T.blas.Dot22) for node in
f.maker.env.toposort() ]) == 1
@staticmethod @staticmethod
def t_gemv1(m_shp): def t_gemv1(m_shp):
''' test vector2+dot(matrix,vector1) ''' ''' test vector2+dot(matrix,vector1) '''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论