提交 d3fb7189 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Add 3d tests for get_conv_output_shape.

上级 1b5d4fc4
......@@ -77,6 +77,23 @@ class TestGetConvOutShape(unittest.TestCase):
self.assertTrue(test3_params == (3, 4, 20, 7))
self.assertTrue(test4_params == (3, 4, 6, 4))
def test_basic_3d(self):
image_shape, kernel_shape = (3, 2, 12, 9, 7), (4, 2, 5, 6, 4)
sub_sample = (1, 2, 1)
filter_dilation = (2, 1, 1)
test1_params = get_conv_output_shape(
image_shape, kernel_shape, 'valid', sub_sample, filter_dilation)
test2_params = get_conv_output_shape(
image_shape, kernel_shape, 'half', sub_sample, filter_dilation)
test3_params = get_conv_output_shape(
image_shape, kernel_shape, 'full', sub_sample, filter_dilation)
test4_params = get_conv_output_shape(
image_shape, kernel_shape, (1, 2, 3), sub_sample, filter_dilation)
self.assertTrue(test1_params == (3, 4, 4, 2, 4))
self.assertTrue(test2_params == (3, 4, 12, 5, 8))
self.assertTrue(test3_params == (3, 4, 20, 7, 10))
self.assertTrue(test4_params == (3, 4, 6, 4, 10))
class BaseTestConv2d:
@classmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论