提交 66f946d4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some problems with cudnn R1.

上级 24e33a70
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <cudnn.h> #include <cudnn.h>
#ifndef CUDNN_VERSION #ifndef CUDNN_VERSION
#include <assert.h>
// Here we define the R2 API in terms of functions in the R1 interface // Here we define the R2 API in terms of functions in the R1 interface
// This is only for what we use // This is only for what we use
...@@ -38,10 +39,10 @@ static inline const char *cudnnGetErrorString(cudnnStatus_t err) { ...@@ -38,10 +39,10 @@ static inline const char *cudnnGetErrorString(cudnnStatus_t err) {
#define cudnnDestroyTensorDescriptor cudnnDestroyTensor4dDescriptor #define cudnnDestroyTensorDescriptor cudnnDestroyTensor4dDescriptor
#define cudnnSetFilter4dDescriptor cudnnSetFilterDescriptor #define cudnnSetFilter4dDescriptor cudnnSetFilterDescriptor
typedef cudnnTensorDescriptor_t cudnnTensor4dDescriptor_t; typedef cudnnTensor4dDescriptor_t cudnnTensorDescriptor_t;
static inline cudnnStatus_t static inline cudnnStatus_t
cdnnGetConvolution2dForwardOutputDim( cudnnGetConvolution2dForwardOutputDim(
const cudnnConvolutionDescriptor_t convDesc, const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t inputTensorDesc, const cudnnTensorDescriptor_t inputTensorDesc,
const cudnnFilterDescriptor_t filterDesc, const cudnnFilterDescriptor_t filterDesc,
...@@ -54,6 +55,9 @@ cdnnGetConvolution2dForwardOutputDim( ...@@ -54,6 +55,9 @@ cdnnGetConvolution2dForwardOutputDim(
} }
typedef int cudnnConvolutionFwdAlgo_t; typedef int cudnnConvolutionFwdAlgo_t;
typedef int cudnnConvolutionFwdPreference_t;
#define CUDNN_CONVOLUTION_FWD_NO_WORKSPACE 0
static inline cudnnStatus_t static inline cudnnStatus_t
cudnnGetConvolutionForwardAlgorithm( cudnnGetConvolutionForwardAlgorithm(
...@@ -73,7 +77,7 @@ static inline cudnnStatus_t ...@@ -73,7 +77,7 @@ static inline cudnnStatus_t
cudnnConvolutionForward_v2( cudnnConvolutionForward_v2(
cudnnHandle_t handle, cudnnHandle_t handle,
const void *alpha, const void *alpha,
const cudnnTensorDescriptor_t srcDest, const cudnnTensorDescriptor_t srcDesc,
const void *srcData, const void *srcData,
const cudnnFilterDescriptor_t filterDesc, const cudnnFilterDescriptor_t filterDesc,
const void *filterData, const void *filterData,
...@@ -119,20 +123,25 @@ static inline cudnnStatus_t ...@@ -119,20 +123,25 @@ static inline cudnnStatus_t
cudnnConvolutionBackwardData_v2( cudnnConvolutionBackwardData_v2(
cudnnHandle_t handle, cudnnHandle_t handle,
const void *alpha, const void *alpha,
const cudnnTensorDescriptor_t filterDesc, const cudnnFilterDescriptor_t filterDesc,
const void *filterData, const void *filterData,
const cudnnTensorDescriptor_t diffDesc, const cudnnTensorDescriptor_t diffDesc,
const void *diffData, const void *diffData,
const cudnnConvolutionDescriptor_t convDesc, const cudnnConvolutionDescriptor_t convDesc,
const void *beta, const void *beta,
const cudnnFilterDescriptor_t gradDesc, const cudnnTensorDescriptor_t gradDesc,
void *gradData) { void *gradData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 0.0); assert(*(float *)beta == 0.0);
return cudnnConvolutionBackwardFilter(handle, filterDesc, filterData, return cudnnConvolutionBackwardData(handle,
diffDesc, diffData, (cudnnFilterDescriptor_t)filterDesc,
convDesc, gradDesc, gradData, filterData,
CUDNN_RESULT_NO_ACCUMULATE); (cudnnTensorDescriptor_t)diffDesc,
diffData,
(cudnnConvolutionDescriptor_t)convDesc,
(cudnnTensorDescriptor_t)gradDesc,
gradData,
CUDNN_RESULT_NO_ACCUMULATE);
} }
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2 #define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论