提交 f3398c59 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

GpuCorrMM (gpuarray) with zero-sized inputs, channels, filters.

上级 382be885
...@@ -528,7 +528,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -528,7 +528,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
def c_code_cache_version(self): def c_code_cache_version(self):
# Raise this whenever modifying the code below. # Raise this whenever modifying the code below.
return (6,) return (7,)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None): def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
""" """
...@@ -1125,7 +1125,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1125,7 +1125,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying the code below. # raise this whenever modifying the code below.
return (6,) return (7,)
def c_code_helper(self, bottom, weights, top, direction, sub, def c_code_helper(self, bottom, weights, top, direction, sub,
height=None, width=None, depth=None): height=None, width=None, depth=None):
......
...@@ -487,6 +487,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -487,6 +487,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
output = top; output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im3d2col, then gemm // valid correlation: im3d2col, then gemm
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
...@@ -538,6 +549,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -538,6 +549,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
} }
else if (direction == 1) { // backprop wrt. weights else if (direction == 1) { // backprop wrt. weights
output = weight; output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im3col, then gemm // valid convolution: im3col, then gemm
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
...@@ -589,9 +611,29 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -589,9 +611,29 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
} }
if (batchSize == 0) {
err = GpuArray_memset(&weight->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
}
} }
else if (direction == 2) { // backprop wrt. inputs else if (direction == 2) { // backprop wrt. inputs
output = bottom; output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im3d // full convolution: gemm, then col2im3d
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
......
...@@ -418,6 +418,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -418,6 +418,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
output = top; output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm // valid correlation: im2col, then gemm
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
...@@ -469,6 +480,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -469,6 +480,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
} }
else if (direction == 1) { // backprop wrt. weights else if (direction == 1) { // backprop wrt. weights
output = weight; output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
...@@ -523,6 +545,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -523,6 +545,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
} }
else if (direction == 2) { // backprop wrt. inputs else if (direction == 2) { // backprop wrt. inputs
output = bottom; output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. inputs could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im // full convolution: gemm, then col2im
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论