提交 e31d548c authored 作者: notoraptor's avatar notoraptor

Workaround only for cuDNN <= 6020.

上级 0cb8fbe7
...@@ -170,8 +170,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -170,8 +170,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)) algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING))
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
/* Algo `small` seems to not work for a batch size > 2^16, with cuDNN >= V5.1. */ // Algo `small` does not work for a batch size > 2^16, with cuDNN >= V5.1.
if (algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM && PyGpuArray_DIM(input, 0) > 65536) // Issue should have been resolved for cuDNN >= V6.0.20.
if (cudnnGetVersion() <= 6020 &&
algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM &&
PyGpuArray_DIM(input, 0) > 65536)
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
// The FFT implementation does not support strides, 1x1 filters or inputs // The FFT implementation does not support strides, 1x1 filters or inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论