提交 e8fdf903 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

GpuDnnConv (sandbox.cuda) with zero-sized inputs, channels, filters.

上级 771cf248
...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
return 1; return 1;
} }
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1)
return 1;
if (c_set_filterNd(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1;
int nb_dim = CudaNdarray_NDIM(input); int nb_dim = CudaNdarray_NDIM(input);
#ifdef CONV_INPLACE #ifdef CONV_INPLACE
...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
return 1; return 1;
#endif #endif
if (CudaNdarray_DIMS(input)[0] == 0 || CudaNdarray_DIMS(kerns)[0] == 0 || CudaNdarray_DIMS(kerns)[1] == 0) {
cudaError_t err2 = cudaMemset((*output)->devdata, 0,
CudaNdarray_SIZE(*output) * sizeof(real));
if (err2 != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv could not fill the output with zeros: %s",
cudaGetErrorString(err2));
return 1;
}
return 0;
}
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1)
return 1;
if (c_set_filterNd(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1;
if (c_set_tensorNd(*output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensorNd(*output, APPLY_SPECIFIC(output)) == -1)
return 1; return 1;
......
...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
return 1; return 1;
} }
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1)
return 1;
if (c_set_filterNd(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1;
int nb_dim = CudaNdarray_NDIM(output); int nb_dim = CudaNdarray_NDIM(output);
#ifdef CONV_INPLACE #ifdef CONV_INPLACE
...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
return 1; return 1;
#endif #endif
if (CudaNdarray_DIMS(im)[0] == 0 || CudaNdarray_DIMS(kerns)[0] == 0 || CudaNdarray_DIMS(kerns)[1] == 0) {
cudaError_t err2 = cudaMemset((*input)->devdata, 0,
CudaNdarray_SIZE(*input) * sizeof(real));
if (err2 != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv grad wrt. inputs could not fill the output with zeros: %s",
cudaGetErrorString(err2));
return 1;
}
return 0;
}
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1)
return 1;
if (c_set_filterNd(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1;
if (c_set_tensorNd(*input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensorNd(*input, APPLY_SPECIFIC(input)) == -1)
return 1; return 1;
......
...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
return 1; return 1;
} }
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1)
return 1;
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1)
return 1;
int nb_dim = CudaNdarray_NDIM(output); int nb_dim = CudaNdarray_NDIM(output);
#ifdef CONV_INPLACE #ifdef CONV_INPLACE
...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -30,6 +25,22 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
return 1; return 1;
#endif #endif
if (CudaNdarray_DIMS(input)[0] == 0 || CudaNdarray_DIMS(km)[0] == 0 || CudaNdarray_DIMS(km)[1] == 0) {
cudaError_t err2 = cudaMemset((*kerns)->devdata, 0,
CudaNdarray_SIZE(*kerns) * sizeof(real));
if (err2 != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv grad wrt. weights could not fill the output with zeros: %s",
cudaGetErrorString(err2));
return 1;
}
return 0;
}
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1)
return 1;
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1)
return 1;
if (c_set_filterNd(*kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filterNd(*kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1; return 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论