提交 b2f47d62 authored 作者: Gabe Schwartz's avatar Gabe Schwartz

Proper input shapes for 3D dilated conv tests.

上级 3938ba99
...@@ -1121,7 +1121,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1 ...@@ -1121,7 +1121,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
conv = GpuDnnConvGradW()(img, kerns, out, desc) conv = GpuDnnConvGradW()(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name) return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name)
elif (border_mode == 'full' and subsample == (1, 1, 1) and elif (border_mode == 'full' and subsample == (1, 1, 1) and dilation == (1, 1, 1) and
direction_hint != 'forward!'): direction_hint != 'forward!'):
# Special case: We can be faster by using GpuDnnConvGradI to compute # Special case: We can be faster by using GpuDnnConvGradI to compute
# the full convolution as the backward pass of a valid convolution. # the full convolution as the backward pass of a valid convolution.
......
...@@ -656,17 +656,15 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -656,17 +656,15 @@ class TestDnnInferShapes(utt.InferShapeTester):
@parameterized.expand(product(border_modes, conv_modes), utt.custom_name_func) @parameterized.expand(product(border_modes, conv_modes), utt.custom_name_func)
def test_conv3d_none(self, border_mode, conv_mode): def test_conv3d_none(self, border_mode, conv_mode):
# CUDNN docs don't say that 3D conv can't handle dilation, but it returns
# CUDNN_STATUS_NOT_SUPPORTED if you try it.
self._test_conv(T.tensor5('img'), self._test_conv(T.tensor5('img'),
T.tensor5('kerns'), T.tensor5('kerns'),
T.tensor5('out'), T.tensor5('out'),
np.random.rand(10, 2, 6, 4, 11), np.random.rand(10, 2, 15, 16, 17),
np.random.rand(8, 2, 4, 3, 1), np.random.rand(8, 2, 4, 3, 1),
border_mode, border_mode,
conv_mode, conv_mode,
[(1, 1, 1), (2, 2, 2)], [(1, 1, 1), (2, 2, 2)],
[(1, 1, 1)], [(1, 1, 1), (2, 2, 2)],
'none') 'none')
def _test_conv_gradw(self, img, topgrad, kerns, img_shape, kerns_shape, border_mode, conv_mode, subsamples, dilations): def _test_conv_gradw(self, img, topgrad, kerns, img_shape, kerns_shape, border_mode, conv_mode, subsamples, dilations):
...@@ -1025,6 +1023,8 @@ def get_conv3d_test_cases(): ...@@ -1025,6 +1023,8 @@ def get_conv3d_test_cases():
[(8, 4, 20, 12, 15), (5, 4, 6, 12, 4), (2, 2, 2), (1, 1, 1)], [(8, 4, 20, 12, 15), (5, 4, 6, 12, 4), (2, 2, 2), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 3, 3), (1, 1, 1)], [(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 3, 3), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 2, 1), (1, 1, 1)], [(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 2, 1), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 3, 4), (1, 1, 2), (3, 2, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 3, 4), (2, 2, 1), (1, 2, 3)],
# Test with 1x1x1 filters # Test with 1x1x1 filters
[(8, 1, 10, 10, 10), (10, 1, 1, 1, 1), (1, 1, 1), (1, 1, 1)], [(8, 1, 10, 10, 10), (10, 1, 1, 1, 1), (1, 1, 1), (1, 1, 1)],
# Test with dimensions larger than 1024 (thread block dim) # Test with dimensions larger than 1024 (thread block dim)
...@@ -1042,7 +1042,8 @@ def get_conv3d_test_cases(): ...@@ -1042,7 +1042,8 @@ def get_conv3d_test_cases():
test_shapes_full = [[(6, 2, 2, 2, 2), (4, 2, 3, 1, 1), (1, 1, 1), (1, 1, 1)], test_shapes_full = [[(6, 2, 2, 2, 2), (4, 2, 3, 1, 1), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 1, 3, 1), (1, 1, 1), (1, 1, 1)], [(6, 2, 2, 2, 2), (4, 2, 1, 3, 1), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 1, 1, 3), (1, 1, 1), (1, 1, 1)], [(6, 2, 2, 2, 2), (4, 2, 1, 1, 3), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (1, 1, 1)]] [(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (3, 2, 1)]]
border_modes = ['valid', 'full', 'half', (1, 2, 3), (3, 2, 1), 1, 2] border_modes = ['valid', 'full', 'half', (1, 2, 3), (3, 2, 1), 1, 2]
conv_modes = ['conv', 'cross'] conv_modes = ['conv', 'cross']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论