提交 4e094fc0 作者: Pascal Lamblin 提交者: GitHub

Merge pull request #5936 from HapeMask/cudnnv6_dilation

Add support for cudnn v6 dilated convolution.
......@@ -5,19 +5,19 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
cudnnStatus_t err;
int pad[3] = {PAD_0, PAD_1, PAD_2};
int strides[3] = {SUB_0, SUB_1, SUB_2};
int upscale[3] = {1, 1, 1};
int dilation[3] = {DIL_0, DIL_1, DIL_2};
#if BORDER_MODE == 0
pad[0] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) - 1;
pad[1] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 3) - 1;
pad[0] = (*(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) - 1) * DIL_0;
pad[1] = (*(npy_int64 *)PyArray_GETPTR1(filt_shp, 3) - 1) * DIL_1;
#if NB_DIMS > 2
pad[2] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) - 1;
pad[2] = (*(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) - 1) * DIL_2;
#endif
#elif BORDER_MODE == 2
pad[0] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) / 2;
pad[1] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 3) / 2;
pad[0] = ((*(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) - 1) * DIL_0 + 1) / 2;
pad[1] = ((*(npy_int64 *)PyArray_GETPTR1(filt_shp, 3) - 1) * DIL_1 + 1) / 2;
#if NB_DIMS > 2
pad[2] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) / 2;
pad[2] = ((*(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) - 1) * DIL_2 + 1) / 2;
#endif
#endif
......@@ -36,6 +36,11 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
}
err = cudnnSetConvolutionNdDescriptor(*desc, NB_DIMS, pad, strides,
upscale, CONV_MODE, PRECISION);
dilation, CONV_MODE, PRECISION);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not set convolution "
"descriptor: %s", cudnnGetErrorString(err));
return -1;
}
return 0;
}
差异被折叠。
......@@ -188,11 +188,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
int nd;
int pad[2];
int stride[2];
int upscale[2];
int dilation[2];
cudnnConvolutionMode_t mode;
cudnnDataType_t data_type;
err = cudnnGetConvolutionNdDescriptor(desc, 2, &nd, pad, stride,
upscale, &mode, &data_type);
dilation, &mode, &data_type);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"error getting convolution properties: %s",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论