提交 2d77d637 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Pad descriptors to 3d if the input tensor is less than that and up the limit to…

Pad descriptors to 3d if the input tensor is less than that and up the limit to the one supported by cudnn (8d).
上级 026a96c4
......@@ -11,22 +11,20 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
case GA_DOUBLE:
dt = CUDNN_DATA_DOUBLE;
break;
#if CUDNN_VERSION > 3000
case GA_HALF:
dt = CUDNN_DATA_HALF;
break;
#endif
default:
PyErr_SetString(PyExc_TypeError, "Non-float datatype in c_set_tensorNd");
return -1;
}
ds = gpuarray_get_elsize(var->ga.typecode);
int strs[5], dims[5], default_stride = 1;
int strs[8], dims[8], default_stride = 1;
unsigned int nd = PyGpuArray_NDIM(var);
if (nd > 5) {
PyErr_SetString(PyExc_TypeError, "Tensor of more than 5d");
if (nd > 8) {
PyErr_SetString(PyExc_TypeError, "Tensor of more than 8d");
return -1;
}
......@@ -38,7 +36,15 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
dims[i] = PyGpuArray_DIM(var, i);
}
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd, dims, strs);
/* Tensors can't be smaller than 3d for cudnn so we pad the
* descriptor if they are */
for (unsigned int i = nd; i < 3; i++) {
strs[i] = 1;
dims[i] = 1;
}
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd < 3 ? 3 : nd,
dims, strs);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"Could not set tensorNd descriptor: %s",
......@@ -65,21 +71,19 @@ c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) {
case GA_DOUBLE:
dt = CUDNN_DATA_DOUBLE;
break;
#if CUDNN_VERSION > 3000
case GA_HALF:
dt = CUDNN_DATA_HALF;
break;
#endif
default:
PyErr_SetString(PyExc_TypeError, "Non-float datatype in c_set_filter");
return -1;
}
int dims[5];
int dims[8];
unsigned int nd = PyGpuArray_NDIM(var);
if (nd > 5) {
PyErr_SetString(PyExc_TypeError, "Tensor of more than 5d");
if (nd > 8) {
PyErr_SetString(PyExc_TypeError, "Tensor of more than 8d");
return -1;
}
......@@ -88,6 +92,13 @@ c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) {
dims[i] = PyGpuArray_DIM(var, i);
}
/* Filters can't be less than 3d so we pad */
for (unsigned int i = nd; i < 3; i++)
dims[i] = 1;
if (nd < 3)
nd = 3;
#if CUDNN_VERSION >= 5000
err = cudnnSetFilterNdDescriptor(desc, dt, CUDNN_TENSOR_NCHW, nd, dims);
#else
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论