提交 687f4b20 authored 作者: affanv14's avatar affanv14

remove duplication of code in c_set_tensor

上级 aacca03a
#section support_code
static int
c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
c_set_tensor_for_conv(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc, int groups) {
cudnnDataType_t dt;
size_t ds;
switch (var->ga.typecode) {
......@@ -42,7 +42,8 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
strs[i] = 1;
dims[i] = 1;
}
//only for grouped convolution i.e when groups > 1
dims[1] = dims[1] / groups;
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd < 3 ? 3 : nd,
dims, strs);
if (err != CUDNN_STATUS_SUCCESS) {
......@@ -55,46 +56,8 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
}
static int
c_set_tensor_for_conv(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc, int groups) {
cudnnDataType_t dt;
size_t ds;
switch (var->ga.typecode) {
case GA_FLOAT:
dt = CUDNN_DATA_FLOAT;
break;
case GA_DOUBLE:
dt = CUDNN_DATA_DOUBLE;
break;
case GA_HALF:
dt = CUDNN_DATA_HALF;
break;
default:
PyErr_SetString(PyExc_TypeError, "Non-float datatype in c_set_tensorNd");
return -1;
}
ds = gpuarray_get_elsize(var->ga.typecode);
int strs[8], dims[8], default_stride = 1;
unsigned int nd = PyGpuArray_NDIM(var);
for (unsigned int _i = nd; _i > 0; _i--) {
unsigned int i = _i - 1;
strs[i] = (PyGpuArray_DIM(var, i) != 1 && PyGpuArray_STRIDE(var, i)) ?
PyGpuArray_STRIDE(var, i)/ds : default_stride;
default_stride *= PyGpuArray_DIM(var, i);
dims[i] = PyGpuArray_DIM(var, i);
}
dims[1] = dims[1] / groups;
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd,
dims, strs);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"Could not set tensorNd descriptor: %s",
cudnnGetErrorString(err));
return -1;
}
return 0;
c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
return c_set_tensor_for_conv(var, desc, 1);
}
static int c_make_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t *desc) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论