提交 cde07713 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Add checks to make sure backward tiled-FFT is not used when it shouldn't be (cuda backend)

上级 b177f3a8
...@@ -159,12 +159,17 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -159,12 +159,17 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
chosen_algo = CONV_ALGO; chosen_algo = CONV_ALGO;
} }
// The FFT implementation (only in v3 and onward) does not support strides, // The FFT implementation (only in V3 and onward) does not support strides,
// 1x1 filters or inputs with a spatial dimension larger than 1024. // 1x1 filters or inputs with a spatial dimension larger than 1024.
// If the chosen implementation is FFT, validate that it can be used // The tiled-FFT implementation (only in V4 onward) does not support
// on the current data and default on a safe implementation if it // strides.
// If the chosen implementation is FFT or tiled-FFT, validate that it can
// be used on the current data and default on a safe implementation if it
// can't. // can't.
if (chosen_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT && nb_dim == 4) // Following code is 2d-specific, but it is fine as FFT and tiled-FFT are
// defined only for 2d-filters
if ((chosen_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
chosen_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT) && nb_dim == 4)
{ {
// Extract the properties of the convolution descriptor // Extract the properties of the convolution descriptor
...@@ -192,12 +197,23 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -192,12 +197,23 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
// Ensure that the selected implementation supports the requested // Ensure that the selected implementation supports the requested
// convolution. Fall back to a safe implementation otherwise. // convolution. Fall back to a safe implementation otherwise.
if (chosen_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)
{
if (stride_v != 1 || stride_h != 1 || input_h > 1024 || if (stride_v != 1 || stride_h != 1 || input_h > 1024 ||
input_w > 1024 || (filter_h == 1 && filter_w == 1)) input_w > 1024 || (filter_h == 1 && filter_w == 1))
{ {
chosen_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; chosen_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
} }
} }
else
{
// chosen_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
if (stride_v != 1 || stride_h != 1)
{
chosen_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
}
}
}
// Infer required workspace size from the chosen implementation // Infer required workspace size from the chosen implementation
err = cudnnGetConvolutionBackwardDataWorkspaceSize(_handle, err = cudnnGetConvolutionBackwardDataWorkspaceSize(_handle,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论