提交 5d2dd362 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6034 from abergeron/split_long

Split long-running test so that it helps travis not to give up
...@@ -3,6 +3,8 @@ from __future__ import absolute_import, print_function, division ...@@ -3,6 +3,8 @@ from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_equals from nose.tools import assert_equals
from nose_parameterized import parameterized
import numpy as np import numpy as np
from six import integer_types from six import integer_types
...@@ -151,11 +153,9 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -151,11 +153,9 @@ class TestCorr3D(utt.InferShapeTester):
@attr('slow') @attr('slow')
def test_basic(self): def test_basic(self):
""" # Tests that basic correlations work for odd and even
Tests that basic correlations work for odd and even # dimensions of image and filter shapes, as well as rectangular
dimensions of image and filter shapes, as well as rectangular # images and filters.
images and filters.
"""
border_modes = ['valid', 'full', 'half', (1, 1, 1), border_modes = ['valid', 'full', 'half', (1, 1, 1),
(2, 1, 1), (1, 2, 1), (1, 1, 2), (2, 1, 1), (1, 2, 1), (1, 1, 2),
(3, 3, 3), 1] (3, 3, 3), 1]
...@@ -180,9 +180,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -180,9 +180,7 @@ class TestCorr3D(utt.InferShapeTester):
@attr('slow') @attr('slow')
def test_subsample(self): def test_subsample(self):
""" # Tests correlation where subsampling != (1,1,1)
Tests correlation where subsampling != (1,1,1)
"""
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 2, 2)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 2, 2))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 1, 1)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 1, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'valid', subsample=(3, 3, 3)) self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'valid', subsample=(3, 3, 3))
...@@ -202,9 +200,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -202,9 +200,7 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3)) self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3))
def test_filter_dilation(self): def test_filter_dilation(self):
""" # Tests correlation where filter dilation != (1,1,1)
Tests correlation where filter dilation != (1,1,1)
"""
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', filter_dilation=(2, 2, 2)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', filter_dilation=(2, 2, 2))
self.validate((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', filter_dilation=(3, 1, 1)) self.validate((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', filter_dilation=(3, 1, 1))
self.validate((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', filter_dilation=(2, 3, 3)) self.validate((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', filter_dilation=(2, 3, 3))
...@@ -224,38 +220,32 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -224,38 +220,32 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3), filter_dilation=(2, 2, 2)) self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3), filter_dilation=(2, 2, 2))
@attr('slow') @parameterized.expand([('valid',), ('full',), ('half',), ((1, 1, 1),),
def test_shape_Constant_tensor(self): ((2, 1, 1),), ((1, 2, 1),), ((1, 1, 2),),
""" ((3, 3, 3),), (1,)])
Tests correlation where the {image,filter}_shape is a Constant tensor. # @attr('slow')
""" def test_shape_Constant_tensor(self, border_mode):
# Tests correlation where the {image,filter}_shape is a Constant tensor
as_t = T.as_tensor_variable as_t = T.as_tensor_variable
border_modes = ['valid', 'full', 'half', (1, 1, 1), (2, 1, 1), self.validate((as_t(3), as_t(2), as_t(7), as_t(5), as_t(5)),
(1, 2, 1), (1, 1, 2), (3, 3, 3), 1] (5, 2, 2, 3, 3), border_mode)
self.validate(as_t([3, 2, 7, 5, 5]), (5, 2, 2, 3, 3), border_mode)
for border_mode in border_modes: self.validate(as_t((3, 2, 7, 5, 5)), (5, 2, 2, 3, 3), border_mode)
self.validate((as_t(3), as_t(2), as_t(7), as_t(5), as_t(5)), self.validate((3, 2, 7, 5, 5), (as_t(5), as_t(2), as_t(2),
(5, 2, 2, 3, 3), border_mode) as_t(3), as_t(3)), 'valid')
self.validate(as_t([3, 2, 7, 5, 5]), (5, 2, 2, 3, 3), border_mode) self.validate((3, 2, 7, 5, 5), as_t([5, 2, 2, 3, 3]), border_mode)
self.validate(as_t((3, 2, 7, 5, 5)), (5, 2, 2, 3, 3), border_mode) self.validate(as_t([3, 2, 7, 5, 5]), as_t([5, 2, 2, 3, 3]),
self.validate((3, 2, 7, 5, 5), (as_t(5), as_t(2), as_t(2), border_mode)
as_t(3), as_t(3)), 'valid')
self.validate((3, 2, 7, 5, 5), as_t([5, 2, 2, 3, 3]), border_mode)
self.validate(as_t([3, 2, 7, 5, 5]), as_t([5, 2, 2, 3, 3]), border_mode)
def test_invalid_filter_shape(self): def test_invalid_filter_shape(self):
""" # Tests scenario where filter_shape[1] != input_shape[1]
Tests scenario where filter_shape[1] != input_shape[1]
"""
self.assertRaises(ValueError, self.validate, self.assertRaises(ValueError, self.validate,
(3, 2, 8, 8, 8), (4, 3, 5, 5, 8), (3, 2, 8, 8, 8), (4, 3, 5, 5, 8),
'valid') 'valid')
def test_full_mode(self): def test_full_mode(self):
""" # Tests basic correlation in full mode and case where filter
Tests basic correlation in full mode and case where filter # is larger than the input image.
is larger than the input image.
"""
self.validate((3, 1, 4, 4, 4), (2, 1, 5, 5, 5), 'full') self.validate((3, 1, 4, 4, 4), (2, 1, 5, 5, 5), 'full')
def f(): def f():
...@@ -263,9 +253,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -263,9 +253,7 @@ class TestCorr3D(utt.InferShapeTester):
self.assertRaises(Exception, f) self.assertRaises(Exception, f)
def test_wrong_input(self): def test_wrong_input(self):
""" # Make sure errors are raised when image and kernel are not 5D tensors
Make sure errors are raised when image and kernel are not 5D tensors
"""
self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5), self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5),
'valid', input=T.dmatrix()) 'valid', input=T.dmatrix())
self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5), self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5),
...@@ -276,9 +264,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -276,9 +264,7 @@ class TestCorr3D(utt.InferShapeTester):
'valid', input=T.dtensor4()) 'valid', input=T.dtensor4())
def test_dtype_upcast(self): def test_dtype_upcast(self):
""" # Checks dtype upcast for Corr3dMM methods.
Checks dtype upcast for Corr3dMM methods.
"""
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("Need cxx for this test") raise SkipTest("Need cxx for this test")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论