提交 7c436d0f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Accept beta = 0.0 and 1.0 in cudnn v1.

上级 d2c39c4e
......@@ -103,11 +103,18 @@ cudnnConvolutionForward_v2(
const cudnnTensorDescriptor_t destDesc,
void *destData) {
assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0);
cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
return cudnnConvolutionForward(handle, srcDesc, srcData,
filterDesc, filterData,
convDesc, destDesc, destData,
CUDNN_RESULT_ACCUMULATE);
r);
}
#define cudnnConvolutionForward cudnnConvolutionForward_v2
......@@ -124,11 +131,18 @@ cudnnConvolutionBackwardFilter_v2(
const cudnnFilterDescriptor_t gradDesc,
void *gradData) {
assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0);
cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
return cudnnConvolutionBackwardFilter(handle, srcDesc, srcData,
diffDesc, diffData,
convDesc, gradDesc, gradData,
CUDNN_RESULT_ACCUMULATE);
r);
}
#define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v2
......@@ -146,7 +160,16 @@ cudnnConvolutionBackwardData_v2(
const cudnnTensorDescriptor_t gradDesc,
void *gradData) {
assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 1.0);
cudnnAccumulateResult_t r;
if (*(float *)beta == 0.0) {
r = CUDNN_RESULT_NO_ACCUMULATE;
} else if (*(float *)beta == 1.0) {
r = CUDNN_RESULT_ACCUMULATE;
} else {
assert(0 && "beta must be 0.0 or 1.0");
}
/* This function needs the casting because its params are not
declared as const */
return cudnnConvolutionBackwardData(handle,
(cudnnFilterDescriptor_t)filterDesc,
filterData,
......@@ -155,7 +178,7 @@ cudnnConvolutionBackwardData_v2(
(cudnnConvolutionDescriptor_t)convDesc,
(cudnnTensorDescriptor_t)gradDesc,
gradData,
CUDNN_RESULT_ACCUMULATE);
r);
}
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论