提交 b171cc23 authored 作者: Mathieu Germain's avatar Mathieu Germain

fix cuDNN v5 cudnnSetPoolingNdDescriptor in pool grad

上级 5813433e
......@@ -42,7 +42,7 @@ APPLY_SPECIFIC(pool) = NULL;
}
if ((err = cudnnCreatePoolingDescriptor(&APPLY_SPECIFIC(pool))) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_MemoryError, "could not allocate pooling descriptor"
"(pool): %s", cudnnGetErrorString(err));
"(pool): %s", cudnnGetErrorString(err));
FAIL;
}
}
......@@ -60,7 +60,7 @@ if (APPLY_SPECIFIC(pool) != NULL) { cudnnDestroyPoolingDescriptor(APPLY_SPECIFIC
int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
PyGpuArrayObject *out,
PyGpuArrayObject *out_grad,
PyArrayObject *ws,
PyArrayObject *ws,
PyArrayObject *stride,
PyArrayObject *pad,
PyGpuArrayObject **inp_grad,
......@@ -109,7 +109,12 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
for(int i = 0; i < ndims; i++) {
s[i] = *((npy_intp*)PyArray_GETPTR1(stride, i));
}
#if CUDNN_VERSION >= 5000
err = cudnnSetPoolingNdDescriptor(APPLY_SPECIFIC(pool), MODE_FLAG, CUDNN_PROPAGATE_NAN, ndims, w, p, s);
#else
err = cudnnSetPoolingNdDescriptor(APPLY_SPECIFIC(pool), MODE_FLAG, ndims, w, p, s);
#endif
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not set op descriptor %s", cudnnGetErrorString(err));
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论