提交 9f231761 authored 作者: --global's avatar --global

Default to safe algo when fft is not supported

上级 ad9646be
...@@ -137,6 +137,47 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -137,6 +137,47 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
chosen_algo = CONV_ALGO; chosen_algo = CONV_ALGO;
} }
// The FFT implementation does not support strides, 1x1 filters or
// inputs with a spatial dimension larger than 1024.
// If the chosen implementation is FFT, validate that it can be used
// on the current data and default on a safe implementation if it
// can't.
if (chosen_algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT)
{
// Extract the properties of the convolution descriptor
int pad_h, pad_w, stride_v, stride_h, upscale_x, upscale_y;
cudnnConvolutionMode_t mode;
err = cudnnGetConvolution2dDescriptor(desc, &pad_h, &pad_w,
&stride_v, &stride_h,
&upscale_x, &upscale_y,
&mode);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv: error getting convolution properties: %s",
cudnnGetErrorString(err));
return 1;
}
// Extract the spatial size of the filters
int filter_h = CudaNdarray_HOST_DIMS(kerns)[3];
int filter_w = CudaNdarray_HOST_DIMS(kerns)[4];
// Extract the spatial size of the input
int input_h = CudaNdarray_HOST_DIMS(input)[3];
int input_w = CudaNdarray_HOST_DIMS(input)[4];
// Ensure that the selected implementation supports the requested
// convolution. Fall back to a safe implementation otherwise.
if (stride_v != 1 || stride_h != 1 || input_h > 1024 ||
input_w > 1024 || (filter_h == 1 && filter_w == 1))
{
chosen_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
}
err = cudnnGetConvolutionForwardWorkspaceSize(_handle, err = cudnnGetConvolutionForwardWorkspaceSize(_handle,
APPLY_SPECIFIC(input), APPLY_SPECIFIC(input),
APPLY_SPECIFIC(kerns), APPLY_SPECIFIC(kerns),
......
...@@ -97,7 +97,54 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -97,7 +97,54 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
} }
else else
{ {
chosen_algo = CONV_ALGO; // The shapes of the input and the output are the same as for the
// last execution. The convolution algorithm used last time can also
// be used here
chosen_algo = APPLY_SPECIFIC(previous_bwd_f_algo);
}
}
else
{
chosen_algo = CONV_ALGO;
}
// The FFT implementation does not support strides, 1x1 filters or
// inputs with a spatial dimension larger than 1024.
// If the chosen implementation is FFT, validate that it can be used
// on the current data and default on a safe implementation if it
// can't.
if (chosen_algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT)
{
// Extract the properties of the convolution descriptor
int pad_h, pad_w, stride_v, stride_h, upscale_x, upscale_y;
cudnnConvolutionMode_t mode;
err = cudnnGetConvolution2dDescriptor(desc, &pad_h, &pad_w,
&stride_v, &stride_h,
&upscale_x, &upscale_y,
&mode);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConvGradW: error getting convolution properties: %s",
cudnnGetErrorString(err));
return 1;
}
// Extract the spatial size of the filters
int filter_h = CudaNdarray_HOST_DIMS(*kerns)[3];
int filter_w = CudaNdarray_HOST_DIMS(*kerns)[4];
// Extract the spatial size of the input
int input_h = CudaNdarray_HOST_DIMS(input)[3];
int input_w = CudaNdarray_HOST_DIMS(input)[4];
// Ensure that the selected implementation supports the requested
// convolution. Fall back to a safe implementation otherwise.
if (stride_v != 1 || stride_h != 1 || input_h > 1024 ||
input_w > 1024 || (filter_h == 1 && filter_w == 1))
{
chosen_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} }
} }
...@@ -129,7 +176,7 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -129,7 +176,7 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
APPLY_SPECIFIC(output), CudaNdarray_DEV_DATA(output), APPLY_SPECIFIC(output), CudaNdarray_DEV_DATA(output),
desc, desc,
chosen_algo, chosen_algo,
&workspace, worksize, workspace, worksize,
(void *)&beta, (void *)&beta,
APPLY_SPECIFIC(kerns), CudaNdarray_DEV_DATA(*kerns)); APPLY_SPECIFIC(kerns), CudaNdarray_DEV_DATA(*kerns));
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论