提交 4182a1ad authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change docstrings to comments.

上级 4475869c
...@@ -3,6 +3,8 @@ from __future__ import absolute_import, print_function, division ...@@ -3,6 +3,8 @@ from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_equals from nose.tools import assert_equals
from nose_parameterized import parameterized
import numpy as np import numpy as np
from six import integer_types from six import integer_types
...@@ -151,11 +153,9 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -151,11 +153,9 @@ class TestCorr3D(utt.InferShapeTester):
@attr('slow') @attr('slow')
def test_basic(self): def test_basic(self):
""" # Tests that basic correlations work for odd and even
Tests that basic correlations work for odd and even # dimensions of image and filter shapes, as well as rectangular
dimensions of image and filter shapes, as well as rectangular # images and filters.
images and filters.
"""
border_modes = ['valid', 'full', 'half', (1, 1, 1), border_modes = ['valid', 'full', 'half', (1, 1, 1),
(2, 1, 1), (1, 2, 1), (1, 1, 2), (2, 1, 1), (1, 2, 1), (1, 1, 2),
(3, 3, 3), 1] (3, 3, 3), 1]
...@@ -180,9 +180,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -180,9 +180,7 @@ class TestCorr3D(utt.InferShapeTester):
@attr('slow') @attr('slow')
def test_subsample(self): def test_subsample(self):
""" # Tests correlation where subsampling != (1,1,1)
Tests correlation where subsampling != (1,1,1)
"""
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 2, 2)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 2, 2))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 1, 1)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', subsample=(2, 1, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'valid', subsample=(3, 3, 3)) self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'valid', subsample=(3, 3, 3))
...@@ -202,9 +200,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -202,9 +200,7 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3)) self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3))
def test_filter_dilation(self): def test_filter_dilation(self):
""" # Tests correlation where filter dilation != (1,1,1)
Tests correlation where filter dilation != (1,1,1)
"""
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', filter_dilation=(2, 2, 2)) self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', filter_dilation=(2, 2, 2))
self.validate((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', filter_dilation=(3, 1, 1)) self.validate((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', filter_dilation=(3, 1, 1))
self.validate((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', filter_dilation=(2, 3, 3)) self.validate((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', filter_dilation=(2, 3, 3))
...@@ -242,18 +238,14 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -242,18 +238,14 @@ class TestCorr3D(utt.InferShapeTester):
border_mode) border_mode)
def test_invalid_filter_shape(self): def test_invalid_filter_shape(self):
""" # Tests scenario where filter_shape[1] != input_shape[1]
Tests scenario where filter_shape[1] != input_shape[1]
"""
self.assertRaises(ValueError, self.validate, self.assertRaises(ValueError, self.validate,
(3, 2, 8, 8, 8), (4, 3, 5, 5, 8), (3, 2, 8, 8, 8), (4, 3, 5, 5, 8),
'valid') 'valid')
def test_full_mode(self): def test_full_mode(self):
""" # Tests basic correlation in full mode and case where filter
Tests basic correlation in full mode and case where filter # is larger than the input image.
is larger than the input image.
"""
self.validate((3, 1, 4, 4, 4), (2, 1, 5, 5, 5), 'full') self.validate((3, 1, 4, 4, 4), (2, 1, 5, 5, 5), 'full')
def f(): def f():
...@@ -261,9 +253,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -261,9 +253,7 @@ class TestCorr3D(utt.InferShapeTester):
self.assertRaises(Exception, f) self.assertRaises(Exception, f)
def test_wrong_input(self): def test_wrong_input(self):
""" # Make sure errors are raised when image and kernel are not 5D tensors
Make sure errors are raised when image and kernel are not 5D tensors
"""
self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5), self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5),
'valid', input=T.dmatrix()) 'valid', input=T.dmatrix())
self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5), self.assertRaises(Exception, self.validate, (3, 2, 8, 8, 8), (4, 2, 5, 5, 5),
...@@ -274,9 +264,7 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -274,9 +264,7 @@ class TestCorr3D(utt.InferShapeTester):
'valid', input=T.dtensor4()) 'valid', input=T.dtensor4())
def test_dtype_upcast(self): def test_dtype_upcast(self):
""" # Checks dtype upcast for Corr3dMM methods.
Checks dtype upcast for Corr3dMM methods.
"""
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("Need cxx for this test") raise SkipTest("Need cxx for this test")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论