提交 330da07d authored 作者: carriepl's avatar carriepl 提交者: Frederic

Retain compatibility with CuDNN v2 (cuda backend)

上级 cde07713
...@@ -3,6 +3,38 @@ ...@@ -3,6 +3,38 @@
#include <cudnn.h> #include <cudnn.h>
// If needed, define element of the V3 interface in terms of elements of
// previous versions
#if defined(CUDNN_VERSION) && CUDNN_VERSION < 3000
// Starting in V3, the cudnnSetConvolutionNdDescriptor has an additional
// parameter that determines the data type in which to do the computation.
// For versions older than V3, we need to define an alias for that function
// that will take the additional parameter as input but ignore it.
static inline cudnnStatus_t cudnnSetConvolutionNdDescriptor_v3(
cudnnConvolutionDescriptor_t convDesc,
int arrayLength,
int padA[],
int filterStrideA[]
int upscaleA[],
cudnnConvolutionMode_t mode,
cudnn_dataType_t dataType)
return cudnnSetConvolutionNdDescriptor(convDesc, arrayLength, padA,
filterStrideA, upscaleA, mode);
)
#endif
// If needed, define element of the V4 interface in terms of elements of
// previous versions
#if defined(CUDNN_VERSION) && CUDNN_VERSION < 4000
#define CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING 5
#define CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING 3
#endif
#ifndef CUDNN_VERSION #ifndef CUDNN_VERSION
#include <assert.h> #include <assert.h>
......
...@@ -362,7 +362,7 @@ class GpuDnnConvDesc(GpuOp): ...@@ -362,7 +362,7 @@ class GpuDnnConvDesc(GpuOp):
} }
} }
err = cudnnSetConvolutionNdDescriptor( err = cudnnSetConvolutionNdDescriptor_v3(
%(desc)s, %(desc)s,
%(nb_dim)d, %(nb_dim)d,
pad, subsample, upscale, pad, subsample, upscale,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论