提交 99295012 · 作者: Markus Roth

Refactor: use decorators for skipping CUDA-only tests.

上级 ff0abb5f
import unittest
import numpy
from nose.plugins.skip import SkipTest
from numpy.testing.decorators import skipif
import theano
from theano import tensor
......@@ -8,13 +8,10 @@ from theano import tensor
from theano.sandbox.cuda.var import float32_shared_constructor as f32sc
from theano.sandbox.cuda import CudaNdarrayType, cuda_available
import theano.sandbox.cuda as cuda
# Skip test if cuda_ndarray is not available.
if cuda_available == False:
raise SkipTest('Optional package cuda disabled')
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_float32_shared_constructor():
npy_row = numpy.zeros((1, 10), dtype='float32')
def eq(a, b):
......@@ -40,7 +37,7 @@ def test_float32_shared_constructor():
f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
CudaNdarrayType((False,) * 4))
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_givens():
# Test that you can use a TensorType expression to replace a
# CudaNdarrayType in the givens dictionary.
......@@ -56,6 +53,7 @@ class T_updates(unittest.TestCase):
# Test that you can use a TensorType expression to update a
# CudaNdarrayType in the updates dictionary.
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_1(self):
data = numpy.float32([1, 2, 3, 4])
x = f32sc(data)
......@@ -67,6 +65,7 @@ class T_updates(unittest.TestCase):
f = theano.function([], y, updates=[(x, cuda.gpu_from_host(x + 1))])
f()
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_2(self):
# This test case uses code mentionned in #698
data = numpy.random.rand(10, 10).astype('float32')
......@@ -80,6 +79,7 @@ class T_updates(unittest.TestCase):
updates=output_updates, givens=output_givens)
output_func()
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_err_ndim(self):
# Test that we raise a good error message when we don't
# have the same number of dimensions.
......@@ -92,6 +92,7 @@ class T_updates(unittest.TestCase):
updates=[(output_var,
output_var.sum())])
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_err_broadcast(self):
# Test that we raise a good error message when we don't
# have the same number of dimensions.
......@@ -104,6 +105,7 @@ class T_updates(unittest.TestCase):
updates=[(output_var,
output_var.sum().dimshuffle('x', 'x'))])
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_broadcast(self):
# Test that we can rebroadcast
data = numpy.random.rand(10, 10).astype('float32')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论