提交 9f37bce1 authored 作者: Gabe Schwartz's avatar Gabe Schwartz

Added support for cudnn v6 dilated convolution.

上级 89aac420
...@@ -5,7 +5,7 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp, ...@@ -5,7 +5,7 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
cudnnStatus_t err; cudnnStatus_t err;
int pad[3] = {PAD_0, PAD_1, PAD_2}; int pad[3] = {PAD_0, PAD_1, PAD_2};
int strides[3] = {SUB_0, SUB_1, SUB_2}; int strides[3] = {SUB_0, SUB_1, SUB_2};
int upscale[3] = {1, 1, 1}; int dilation[3] = {DIL_0, DIL_1, DIL_2};
#if BORDER_MODE == 0 #if BORDER_MODE == 0
pad[0] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) - 1; pad[0] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) - 1;
...@@ -36,6 +36,11 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp, ...@@ -36,6 +36,11 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
} }
err = cudnnSetConvolutionNdDescriptor(*desc, NB_DIMS, pad, strides, err = cudnnSetConvolutionNdDescriptor(*desc, NB_DIMS, pad, strides,
upscale, CONV_MODE, PRECISION); dilation, CONV_MODE, PRECISION);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not set convolution "
"descriptor: %s", cudnnGetErrorString(err));
return -1;
}
return 0; return 0;
} }
差异被折叠。
...@@ -188,11 +188,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -188,11 +188,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
int nd; int nd;
int pad[2]; int pad[2];
int stride[2]; int stride[2];
int upscale[2]; int dilation[2];
cudnnConvolutionMode_t mode; cudnnConvolutionMode_t mode;
cudnnDataType_t data_type; cudnnDataType_t data_type;
err = cudnnGetConvolutionNdDescriptor(desc, 2, &nd, pad, stride, err = cudnnGetConvolutionNdDescriptor(desc, 2, &nd, pad, stride,
upscale, &mode, &data_type); dilation, &mode, &data_type);
if (err != CUDNN_STATUS_SUCCESS) { if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"error getting convolution properties: %s", "error getting convolution properties: %s",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论