提交 7c436d0f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Accept beta = 0.0 and 1.0 in cudnn v1.

上级 d2c39c4e
...@@ -103,11 +103,18 @@ cudnnConvolutionForward_v2( ...@@ -103,11 +103,18 @@ cudnnConvolutionForward_v2(
const cudnnTensorDescriptor_t destDesc, const cudnnTensorDescriptor_t destDesc,
void *destData) { void *destData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0); cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
return cudnnConvolutionForward(handle, srcDesc, srcData, return cudnnConvolutionForward(handle, srcDesc, srcData,
filterDesc, filterData, filterDesc, filterData,
convDesc, destDesc, destData, convDesc, destDesc, destData,
CUDNN_RESULT_ACCUMULATE); r);
} }
#define cudnnConvolutionForward cudnnConvolutionForward_v2 #define cudnnConvolutionForward cudnnConvolutionForward_v2
...@@ -124,11 +131,18 @@ cudnnConvolutionBackwardFilter_v2( ...@@ -124,11 +131,18 @@ cudnnConvolutionBackwardFilter_v2(
const cudnnFilterDescriptor_t gradDesc, const cudnnFilterDescriptor_t gradDesc,
void *gradData) { void *gradData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0); cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
return cudnnConvolutionBackwardFilter(handle, srcDesc, srcData, return cudnnConvolutionBackwardFilter(handle, srcDesc, srcData,
diffDesc, diffData, diffDesc, diffData,
convDesc, gradDesc, gradData, convDesc, gradDesc, gradData,
CUDNN_RESULT_ACCUMULATE); r);
} }
#define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v2 #define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v2
...@@ -146,7 +160,16 @@ cudnnConvolutionBackwardData_v2( ...@@ -146,7 +160,16 @@ cudnnConvolutionBackwardData_v2(
const cudnnTensorDescriptor_t gradDesc, const cudnnTensorDescriptor_t gradDesc,
void *gradData) { void *gradData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0); cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
/* This function needs the casting because its params are not
declared as const */
return cudnnConvolutionBackwardData(handle, return cudnnConvolutionBackwardData(handle,
(cudnnFilterDescriptor_t)filterDesc, (cudnnFilterDescriptor_t)filterDesc,
filterData, filterData,
...@@ -155,7 +178,7 @@ cudnnConvolutionBackwardData_v2( ...@@ -155,7 +178,7 @@ cudnnConvolutionBackwardData_v2(
(cudnnConvolutionDescriptor_t)convDesc, (cudnnConvolutionDescriptor_t)convDesc,
(cudnnTensorDescriptor_t)gradDesc, (cudnnTensorDescriptor_t)gradDesc,
gradData, gradData,
CUDNN_RESULT_ACCUMULATE); r);
} }
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2 #define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论