提交 268bc917 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron 提交者: --global

Extend the current pooling op to support 3d pooling.

This does not have any tests or optimizations for 3d pooling, but the 2d support still works at least.
上级 525c9c84
......@@ -197,6 +197,85 @@ cudnnConvolutionBackwardData_v2(
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
static inline cudnnStatus_t
cudnnSetPoolingNdDescriptor(
cudnnPoolingDescriptor_t poolingDesc,
const cudnnPoolingMode_t mode,
int nbDims,
const int windowDimA[],
const int paddingA[],
const in strideA[]) {
if (nbDims != 2) return CUDNN_STATUS_NOT_SUPPORTED;
if (paddingA[0] != 0 || paddingA[1] != 0) return CUDNN_STATUS_NOT_SUPPORTED;
return cudnnSetPoolingDescriptor(poolingDesc, mode,
windowDimA[0], windowDimA[1],
strideA[0], strideA[1]);
}
static inline cudnnStatus_t
cudnnGetPoolingNdDescriptor(
const cudnnPoolingDescriptor_t poolingDesc,
const int nbDimsRequested,
cudnnPoolingMode_t *mode,
int *nbDims,
int windowA[],
int paddingA[],
int strideA[]) {
int win0, win1, str0, str1;
cudnnStatus_t err;
if (ndDimsRequested < 2) return CUDNN_STATUS_NOT_SUPPORTED;
err = cudnnGetPoolingDescriptor(poolingDesc, mode, &win0, &win1,
&str0, &str1);
if (err != CUDNN_STATUS_SUCCESS) return err;
*nbDims = 2;
paddingA[0] = 0;
paddingA[1] = 0;
windowA[0] = win0;
windowA[1] = win1;
strideA[0] = str0;
strideA[1] = str1;
return CUDNN_STATUS_SUCCESS;
}
static inline cudnnStatus_t
cudnnPoolingForward_v2(
cudnnHandle_t handle,
const cudnnPoolingDescriptor_t poolingDesc,
const void *alpha,
const cudnnTensorDescriptor_t srcDesc,
const void *srcData,
const void *beta,
const cudnnTensorDescriptor_t destDesc,
void *destData) {
if (*(float*)alpha != 1.0 || *(float *)beta != 0.0) return CUDNN_STATUS_NOT_SUPPORTED;
return cudnnPoolingForward(handle, poolingDesc, srcDesc, srcData,
destDesc, destData);
}
#define cudnnPoolingForward cudnnPoolingForward_v2
static inline cudnnStatus_t
cudnnPoolingBackward_v2(
cudnnHandle_t handle,
const cudnnPoolingDescriptor_t poolingDesc,
const void *alpha,
const cudnnTensorDescriptor_t srcDesc,
const void *srcData,
const cudnnTensorDescriptor_t srcDiffDesc,
const void *srcDiffData,
const cudnnTensorDescriptor_t destDesc,
const void *destData,
const void *beta,
const cudnnTensorDescriptor_t destDiffDesc,
void *destDiffData) {
if (*(float*)alpha != 1.0 || *(float *)beta != 0.0) return CUDNN_STATUS_NOT_SUPPORTED;
return cudnnPoolingBackward(handle, poolingDesc,
srcDesc, srcData,
srcDiffDesc, srcDiffData,
destDesc, destData,
destDiffDesc, destDiffData);
}
#define cudnnPoolingBackward cudnnPoolingBackward_v2
//Needed for R2 rc2
# define CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING CUDNN_POOLING_AVERAGE
#else
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论