提交 b2f47d62 authored 作者: Gabe Schwartz's avatar Gabe Schwartz

Proper input shapes for 3D dilated conv tests.

上级 3938ba99
......@@ -1121,7 +1121,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
conv = GpuDnnConvGradW()(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name)
elif (border_mode == 'full' and subsample == (1, 1, 1) and
elif (border_mode == 'full' and subsample == (1, 1, 1) and dilation == (1, 1, 1) and
direction_hint != 'forward!'):
# Special case: We can be faster by using GpuDnnConvGradI to compute
# the full convolution as the backward pass of a valid convolution.
......
......@@ -656,17 +656,15 @@ class TestDnnInferShapes(utt.InferShapeTester):
@parameterized.expand(product(border_modes, conv_modes), utt.custom_name_func)
def test_conv3d_none(self, border_mode, conv_mode):
# CUDNN docs don't say that 3D conv can't handle dilation, but it returns
# CUDNN_STATUS_NOT_SUPPORTED if you try it.
self._test_conv(T.tensor5('img'),
T.tensor5('kerns'),
T.tensor5('out'),
np.random.rand(10, 2, 6, 4, 11),
np.random.rand(10, 2, 15, 16, 17),
np.random.rand(8, 2, 4, 3, 1),
border_mode,
conv_mode,
[(1, 1, 1), (2, 2, 2)],
[(1, 1, 1)],
[(1, 1, 1), (2, 2, 2)],
'none')
def _test_conv_gradw(self, img, topgrad, kerns, img_shape, kerns_shape, border_mode, conv_mode, subsamples, dilations):
......@@ -1025,6 +1023,8 @@ def get_conv3d_test_cases():
[(8, 4, 20, 12, 15), (5, 4, 6, 12, 4), (2, 2, 2), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 3, 3), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 12, 4), (3, 2, 1), (1, 1, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 3, 4), (1, 1, 2), (3, 2, 1)],
[(8, 1, 20, 12, 15), (5, 1, 6, 3, 4), (2, 2, 1), (1, 2, 3)],
# Test with 1x1x1 filters
[(8, 1, 10, 10, 10), (10, 1, 1, 1, 1), (1, 1, 1), (1, 1, 1)],
# Test with dimensions larger than 1024 (thread block dim)
......@@ -1042,7 +1042,8 @@ def get_conv3d_test_cases():
test_shapes_full = [[(6, 2, 2, 2, 2), (4, 2, 3, 1, 1), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 1, 3, 1), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 1, 1, 3), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (1, 1, 1)]]
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (1, 1, 1)],
[(6, 2, 2, 2, 2), (4, 2, 5, 5, 5), (1, 1, 1), (3, 2, 1)]]
border_modes = ['valid', 'full', 'half', (1, 2, 3), (3, 2, 1), 1, 2]
conv_modes = ['conv', 'cross']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论