提交 dab48647 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Expand another corr3d test

上级 37608c88
......@@ -199,25 +199,25 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3))
def test_filter_dilation(self):
# Tests correlation where filter dilation != (1,1,1)
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', filter_dilation=(2, 2, 2))
self.validate((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', filter_dilation=(3, 1, 1))
self.validate((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', filter_dilation=(2, 3, 3))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'full', filter_dilation=(2, 2, 2))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'full', filter_dilation=(3, 1, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'full', filter_dilation=(2, 3, 3))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'half', filter_dilation=(2, 2, 2))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'half', filter_dilation=(3, 1, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'half', filter_dilation=(2, 3, 3))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), (1, 1, 1), filter_dilation=(2, 2, 2))
self.validate((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), (2, 1, 1), filter_dilation=(2, 1, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), (1, 2, 1), filter_dilation=(1, 2, 1))
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), (1, 1, 2), filter_dilation=(1, 1, 2))
# Tests correlation where filter dilation != (1,1,1)
@parameterized.expand([
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'valid', (2, 2, 2)),
((3, 2, 14, 10, 10), (2, 2, 2, 3, 3), 'valid', (3, 1, 1)),
((1, 1, 14, 14, 14), (1, 1, 3, 3, 3), 'valid', (2, 3, 3)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'full', (2, 2, 2)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'full', (3, 1, 1)),
((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'full', (2, 3, 3)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'half', (2, 2, 2)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), 'half', (3, 1, 1)),
((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 'half', (2, 3, 3)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), (1, 1, 1), (2, 2, 2)),
((3, 2, 7, 5, 5), (2, 2, 2, 3, 3), (2, 1, 1), (2, 1, 1)),
((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), (1, 2, 1), (1, 2, 1)),
((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), (1, 1, 2), (1, 1, 2))])
def test_filter_dilation(self, image_shape, filter_shape, border_mode, filter_dilation):
self.validate(image_shape, filter_shape, border_mode, filter_dilation=filter_dilation)
def test_filter_dilation_subsample(self):
self.validate((1, 1, 6, 6, 6), (1, 1, 3, 3, 3), 1, subsample=(3, 3, 3), filter_dilation=(2, 2, 2))
@parameterized.expand([('valid',), ('full',), ('half',), ((1, 1, 1),),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论