提交 0cbfeae8 authored 作者: notoraptor's avatar notoraptor

Add 3D tests and add a fix

(replace algo `small` by `none` if necessary).
上级 2d41daaf
...@@ -163,10 +163,15 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -163,10 +163,15 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
} }
} }
/* These two algos are not supported for 3d conv */ /* Only these algos are supported for 3d conv with cuDNN >= V5.1. */
if (PyGpuArray_NDIM(input) == 5 && if (PyGpuArray_NDIM(input) == 5 &&
(algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM || !(algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM ||
algo == CUDNN_CONVOLUTION_FWD_ALGO_GEMM)) algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM ||
algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING))
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
/* Algo `small` seems to not work for a batch size > 2^16, with cuDNN >= V5.1. */
if (algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM && PyGpuArray_DIM(input, 0) > 65536)
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
// The FFT implementation does not support strides, 1x1 filters or inputs // The FFT implementation does not support strides, 1x1 filters or inputs
......
...@@ -1068,7 +1068,7 @@ def get_conv3d_test_cases(): ...@@ -1068,7 +1068,7 @@ def get_conv3d_test_cases():
def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub, subsample): def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub, subsample):
# Run function for issue $5985 (see tests below): https://github.com/Theano/Theano/issues/5985 # Function to check issue #5985 (see tests below): https://github.com/Theano/Theano/issues/5985
algo = 'small' algo = 'small'
batch_size = inputs_shape[0] batch_size = inputs_shape[0]
...@@ -1081,16 +1081,20 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub, ...@@ -1081,16 +1081,20 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub,
inputs = theano.shared(inputs_val) inputs = theano.shared(inputs_val)
filters = theano.shared(filters_val) filters = theano.shared(filters_val)
conv = dnn.dnn_conv(img=inputs, kerns=filters, algo=algo, subsample=subsample) if len(inputs_shape) == 5:
dnn_func = dnn.dnn_conv3d
else:
dnn_func = dnn.dnn_conv
conv = dnn_func(img=inputs, kerns=filters, algo=algo, subsample=subsample)
# Just compute first and last outputs to reduce execution time. # Just compute first and last outputs to reduce execution time.
sub_conv_top = dnn.dnn_conv(img=inputs[:batch_sub], sub_conv_top = dnn_func(img=inputs[:batch_sub],
kerns=filters, algo=algo, subsample=subsample) kerns=filters, algo=algo, subsample=subsample)
sub_conv_bottom = dnn.dnn_conv(img=inputs[(batch_size - batch_sub):], sub_conv_bottom = dnn_func(img=inputs[(batch_size - batch_sub):],
kerns=filters, algo=algo, subsample=subsample) kerns=filters, algo=algo, subsample=subsample)
f = theano.function([], [conv, sub_conv_top, sub_conv_bottom], mode=mode_with_gpu) f = theano.function([], [conv, sub_conv_top, sub_conv_bottom], mode=mode_with_gpu)
res_all, res_batch_top, res_batch_bottom = f() res_all, res_batch_top, res_batch_bottom = f()
for i in range(0, batch_sub): for i in range(0, batch_sub):
utt.assert_allclose(res_all[i], res_batch_top[i]) utt.assert_allclose(res_batch_top[i], res_all[i])
p = batch_size - batch_sub + i p = batch_size - batch_sub + i
# It seems there is a limit batch size of 65536 for a good computation # It seems there is a limit batch size of 65536 for a good computation
# with algorithm `small`. # with algorithm `small`.
...@@ -1100,7 +1104,7 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub, ...@@ -1100,7 +1104,7 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub,
# It should not happen. # It should not happen.
if np.allclose(res_all[p % checked_limit], res_all[p]): if np.allclose(res_all[p % checked_limit], res_all[p]):
print('\nconv[%d] == conv[%d] == %s' % (p % checked_limit, p, res_all[p])) print('\nconv[%d] == conv[%d] == %s' % (p % checked_limit, p, res_all[p]))
utt.assert_allclose(res_all[p], res_batch_bottom[i]) utt.assert_allclose(res_batch_bottom[i], res_all[p])
def test_batched_conv_small(): def test_batched_conv_small():
...@@ -1110,6 +1114,13 @@ def test_batched_conv_small(): ...@@ -1110,6 +1114,13 @@ def test_batched_conv_small():
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # ERROR yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # ERROR
def test_batched_conv3d_small():
yield (run_conv_small_batched_vs_multicall, (65534, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65535, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # ERROR ALSO.
def test_conv3d_fwd(): def test_conv3d_fwd():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论