提交 771cf248 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

GpuCorrMM (sandbox.cuda) with zero-sized inputs, channels, filters.

上级 b0b0f076
...@@ -922,7 +922,7 @@ class BaseGpuCorrMM(GpuOp): ...@@ -922,7 +922,7 @@ class BaseGpuCorrMM(GpuOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (0, 29) return (0, 30)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
...@@ -1513,7 +1513,7 @@ class BaseGpuCorr3dMM(GpuOp): ...@@ -1513,7 +1513,7 @@ class BaseGpuCorr3dMM(GpuOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (0, 28) return (0, 29)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
......
...@@ -486,6 +486,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom, ...@@ -486,6 +486,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
if (direction == 0) if (direction == 0)
{ // forward pass { // forward pass
output = top; output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm // valid correlation: im2col, then gemm
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) for (int n = 0; n < batchSize; n++)
...@@ -535,6 +548,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom, ...@@ -535,6 +548,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
{ {
// backprop wrt. weights // backprop wrt. weights
output = weight; output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) for (int n = 0; n < batchSize; n++)
...@@ -586,6 +612,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom, ...@@ -586,6 +612,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
{ {
// backprop wrt. inputs // backprop wrt. inputs
output = bottom; output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im3d // full convolution: gemm, then col2im3d
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) for (int n = 0; n < batchSize; n++)
......
...@@ -384,6 +384,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -384,6 +384,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
CudaNdarray *output; CudaNdarray *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
output = top; output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm // valid correlation: im2col, then gemm
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) { for (int n = 0; n < batchSize; n++) {
...@@ -452,6 +465,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -452,6 +465,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
} }
else if (direction == 1) { // backprop wrt. weights else if (direction == 1) { // backprop wrt. weights
output = weight; output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. weights could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) { for (int n = 0; n < batchSize; n++) {
...@@ -520,6 +546,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -520,6 +546,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
} }
else if (direction == 2) { // backprop wrt. inputs else if (direction == 2) { // backprop wrt. inputs
output = bottom; output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. inputs could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im // full convolution: gemm, then col2im
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) { for (int n = 0; n < batchSize; n++) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论