提交 351efd7a authored 作者: Frederic's avatar Frederic

moved some gpu ifelse tests to the common cpu/gpu tests class.

上级 f46a6c00
...@@ -339,6 +339,7 @@ class TestIfElse(theano.tests.test_ifelse.test_ifelse): ...@@ -339,6 +339,7 @@ class TestIfElse(theano.tests.test_ifelse.test_ifelse):
dtype = "float32" dtype = "float32"
mode = mode_with_gpu mode = mode_with_gpu
cast_output = staticmethod(basic_ops.as_cuda_ndarray_variable) cast_output = staticmethod(basic_ops.as_cuda_ndarray_variable)
shared = staticmethod(cuda.shared_constructor)
def get_ifelse(self, n): def get_ifelse(self, n):
return theano.ifelse.IfElse(n, gpu=True, as_view=True) return theano.ifelse.IfElse(n, gpu=True, as_view=True)
......
...@@ -5,7 +5,6 @@ from nose.plugins.skip import SkipTest ...@@ -5,7 +5,6 @@ from nose.plugins.skip import SkipTest
import theano import theano
from theano import tensor from theano import tensor
from theano.ifelse import ifelse
from theano import sparse from theano import sparse
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -91,77 +90,3 @@ class T_updates(unittest.TestCase): ...@@ -91,77 +90,3 @@ class T_updates(unittest.TestCase):
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates={output_var: output_var.sum().dimshuffle('x', 'x')}) updates={output_var: output_var.sum().dimshuffle('x', 'x')})
output_func() output_func()
class T_ifelse(unittest.TestCase):
def setUp(self):
utt.seed_rng()
self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
def test_cuda_tensor(self):
data = self.rng.rand(4).astype('float32')
x = f32sc(data)
y = x + 1
cond = theano.tensor.iscalar('cond')
assert isinstance(x.type, CudaNdarrayType)
assert isinstance(y.type, TensorType)
out1 = ifelse(cond, x, y)
out2 = ifelse(cond, y, x)
assert isinstance(out1.type, TensorType)
assert isinstance(out2.type, TensorType)
f = theano.function([cond], out1)
g = theano.function([cond], out2)
assert numpy.all(f(0) == data + 1)
assert numpy.all(f(1) == data)
assert numpy.all(g(0) == data)
assert numpy.all(g(1) == data + 1)
def test_dtype_mismatch(self):
data = self.rng.rand(5).astype('float32')
x = f32sc(data)
y = tensor.cast(x, 'float64')
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_ndim_mismatch(self):
data = self.rng.rand(5).astype('float32')
x = f32sc(data)
y = tensor.fcol('y')
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_broadcast_mismatch(self):
data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data)
print x.broadcastable
y = tensor.frow('y')
print y.broadcastable
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_sparse_tensor_error(self):
data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data)
y = sparse.matrix('csc', dtype='float32', name='y')
z = sparse.matrix('csr', dtype='float32', name='z')
cond = theano.tensor.iscalar('cond')
# Right now (2012-01-19), a ValueError gets raised, but I thing
# a TypeError (like in the other cases) would be fine.
self.assertRaises((TypeError, ValueError), ifelse, cond, x, y)
self.assertRaises((TypeError, ValueError), ifelse, cond, y, x)
self.assertRaises((TypeError, ValueError), ifelse, cond, x, z)
self.assertRaises((TypeError, ValueError), ifelse, cond, z, x)
self.assertRaises((TypeError, ValueError), ifelse, cond, y, z)
self.assertRaises((TypeError, ValueError), ifelse, cond, z, y)
...@@ -22,6 +22,7 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -22,6 +22,7 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
mode = None mode = None
dtype = theano.config.floatX dtype = theano.config.floatX
cast_output = staticmethod(tensor.as_tensor_variable) cast_output = staticmethod(tensor.as_tensor_variable)
shared = staticmethod(theano.shared)
def get_ifelse(self, n): def get_ifelse(self, n):
if theano.config.mode == "FAST_COMPILE": if theano.config.mode == "FAST_COMPILE":
...@@ -157,6 +158,56 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -157,6 +158,56 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
assert numpy.all(outs_0[2] == 1.) assert numpy.all(outs_0[2] == 1.)
assert numpy.all(outs_0[3] == 1.) assert numpy.all(outs_0[3] == 1.)
def test_dtype_mismatch(self):
rng = numpy.random.RandomState(utt.fetch_seed())
data = rng.rand(5).astype(self.dtype)
x = self.shared(data)
y = tensor.cast(x * 10, 'int8')
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_ndim_mismatch(self):
rng = numpy.random.RandomState(utt.fetch_seed())
data = rng.rand(5).astype(self.dtype)
x = self.shared(data)
y = tensor.col('y', self.dtype)
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_broadcast_mismatch(self):
rng = numpy.random.RandomState(utt.fetch_seed())
data = rng.rand(5).astype(self.dtype)
x = self.shared(data)
#print x.broadcastable
y = tensor.row('y', self.dtype)
#print y.broadcastable
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_sparse_tensor_error(self):
import theano.sparse
if not theano.sparse.enable_sparse:
raise SkipTest("Optimization temporarily disabled")
rng = numpy.random.RandomState(utt.fetch_seed())
data = rng.rand(2, 3).astype(self.dtype)
x = self.shared(data)
y = theano.sparse.matrix('csc', dtype=self.dtype, name='y')
z = theano.sparse.matrix('csr', dtype=self.dtype, name='z')
cond = theano.tensor.iscalar('cond')
self.assertRaises(TypeError, ifelse, cond, x, y)
self.assertRaises(TypeError, ifelse, cond, y, x)
self.assertRaises(TypeError, ifelse, cond, x, z)
self.assertRaises(TypeError, ifelse, cond, z, x)
self.assertRaises(TypeError, ifelse, cond, y, z)
self.assertRaises(TypeError, ifelse, cond, z, y)
def test_merge(self): def test_merge(self):
raise SkipTest("Optimization temporarily disabled") raise SkipTest("Optimization temporarily disabled")
x = tensor.vector('x') x = tensor.vector('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论