提交 b177f3a8 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Add checks to make sure tiled-fft is not used when it shouldn't (cuda backend)

上级 cf26ae5c
...@@ -164,12 +164,15 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -164,12 +164,15 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 3000 #if defined(CUDNN_VERSION) && CUDNN_VERSION >= 3000
// The FFT implementation (only in V3 and onward) does not support strides, // The FFT implementation (only in V3 and onward) does not support strides,
// 1x1 filters or inputs with a spatial dimension larger than 1024. // 1x1 filters or inputs with a spatial dimension larger than 1024.
// If the chosen implementation is FFT, validate that it can be used // The tiled-FFT implementation (only in V4 onward) does not support
// on the current data and default on a safe implementation if it // strides.
// If the chosen implementation is FFT or tiled-FFT, validate that it can
// be used on the current data and default on a safe implementation if it
// can't. // can't.
// Following code is 2d-specific, but it is fine as ftt is defined only for // Following code is 2d-specific, but it is fine as FFT and tiled-FFT are
// 2d-filters // defined only for 2d-filters
if (chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT && nb_dim == 4) if ((chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT ||
chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) && nb_dim == 4)
{ {
// Extract the properties of the convolution descriptor // Extract the properties of the convolution descriptor
...@@ -197,10 +200,21 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -197,10 +200,21 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
// Ensure that the selected implementation supports the requested // Ensure that the selected implementation supports the requested
// convolution. Fall back to a safe implementation otherwise. // convolution. Fall back to a safe implementation otherwise.
if (stride_v != 1 || stride_h != 1 || input_h > 1024 || if (chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT)
input_w > 1024 || (filter_h == 1 && filter_w == 1))
{ {
chosen_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; if (stride_v != 1 || stride_h != 1 || input_h > 1024 ||
input_w > 1024 || (filter_h == 1 && filter_w == 1))
{
chosen_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
}
else
{
// chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
if (stride_v != 1 || stride_h != 1)
{
chosen_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
} }
} }
#endif #endif
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论