提交 d24f5995 authored 作者: --global's avatar --global

Test conv3d also with integer subsample values

上级 e4f276a5
...@@ -1079,7 +1079,7 @@ def get_conv3d_test_cases(): ...@@ -1079,7 +1079,7 @@ def get_conv3d_test_cases():
[(6, 2, 2, 2, 2), (4, 2, 1, 1, 3), (1, 1, 1)], [(6, 2, 2, 2, 2), (4, 2, 1, 1, 3), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1)], [(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1)],
] ]
border_modes = ['valid', 'full', (1, 2, 3), (3, 2, 1)] border_modes = ['valid', 'full', (1, 2, 3), (3, 2, 1), 1, 2]
conv_modes = ['conv', 'cross'] conv_modes = ['conv', 'cross']
itt = chain(product(test_shapes, border_modes, conv_modes), itt = chain(product(test_shapes, border_modes, conv_modes),
...@@ -1122,6 +1122,9 @@ def test_conv3d_fwd(): ...@@ -1122,6 +1122,9 @@ def test_conv3d_fwd():
else: else:
if border_mode == 'full': if border_mode == 'full':
pad_per_dim = [filters_shape[i] - 1 for i in range(2,5)] pad_per_dim = [filters_shape[i] - 1 for i in range(2,5)]
else:
if isinstance(border_mode, int):
pad_per_dim = [border_mode] * 3
else: else:
pad_per_dim = border_mode pad_per_dim = border_mode
...@@ -1186,6 +1189,9 @@ def test_conv3d_bwd(): ...@@ -1186,6 +1189,9 @@ def test_conv3d_bwd():
else: else:
if border_mode == 'full': if border_mode == 'full':
pad_per_dim = [filters_shape[i] - 1 for i in range(2,5)] pad_per_dim = [filters_shape[i] - 1 for i in range(2,5)]
else:
if isinstance(border_mode, int):
pad_per_dim = [border_mode] * 3
else: else:
pad_per_dim = border_mode pad_per_dim = border_mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论