提交 66f946d4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some problems with cudnn R1.

上级 24e33a70
......@@ -4,6 +4,7 @@
#include <cudnn.h>
#ifndef CUDNN_VERSION
#include <assert.h>
// Here we define the R2 API in terms of functions in the R1 interface
// This is only for what we use
......@@ -38,10 +39,10 @@ static inline const char *cudnnGetErrorString(cudnnStatus_t err) {
#define cudnnDestroyTensorDescriptor cudnnDestroyTensor4dDescriptor
#define cudnnSetFilter4dDescriptor cudnnSetFilterDescriptor
typedef cudnnTensorDescriptor_t cudnnTensor4dDescriptor_t;
typedef cudnnTensor4dDescriptor_t cudnnTensorDescriptor_t;
static inline cudnnStatus_t
cdnnGetConvolution2dForwardOutputDim(
cudnnGetConvolution2dForwardOutputDim(
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t inputTensorDesc,
const cudnnFilterDescriptor_t filterDesc,
......@@ -54,6 +55,9 @@ cdnnGetConvolution2dForwardOutputDim(
}
typedef int cudnnConvolutionFwdAlgo_t;
typedef int cudnnConvolutionFwdPreference_t;
#define CUDNN_CONVOLUTION_FWD_NO_WORKSPACE 0
static inline cudnnStatus_t
cudnnGetConvolutionForwardAlgorithm(
......@@ -73,7 +77,7 @@ static inline cudnnStatus_t
cudnnConvolutionForward_v2(
cudnnHandle_t handle,
const void *alpha,
const cudnnTensorDescriptor_t srcDest,
const cudnnTensorDescriptor_t srcDesc,
const void *srcData,
const cudnnFilterDescriptor_t filterDesc,
const void *filterData,
......@@ -119,20 +123,25 @@ static inline cudnnStatus_t
cudnnConvolutionBackwardData_v2(
cudnnHandle_t handle,
const void *alpha,
const cudnnTensorDescriptor_t filterDesc,
const cudnnFilterDescriptor_t filterDesc,
const void *filterData,
const cudnnTensorDescriptor_t diffDesc,
const void *diffData,
const cudnnConvolutionDescriptor_t convDesc,
const void *beta,
const cudnnFilterDescriptor_t gradDesc,
const cudnnTensorDescriptor_t gradDesc,
void *gradData) {
assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 0.0);
return cudnnConvolutionBackwardFilter(handle, filterDesc, filterData,
diffDesc, diffData,
convDesc, gradDesc, gradData,
CUDNN_RESULT_NO_ACCUMULATE);
return cudnnConvolutionBackwardData(handle,
(cudnnFilterDescriptor_t)filterDesc,
filterData,
(cudnnTensorDescriptor_t)diffDesc,
diffData,
(cudnnConvolutionDescriptor_t)convDesc,
(cudnnTensorDescriptor_t)gradDesc,
gradData,
CUDNN_RESULT_NO_ACCUMULATE);
}
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论