提交 f3398c59 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

GpuCorrMM (gpuarray) with zero-sized inputs, channels, filters.

上级 382be885
......@@ -528,7 +528,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
def c_code_cache_version(self):
# Raise this whenever modifying the code below.
return (6,)
return (7,)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
"""
......@@ -1125,7 +1125,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
def c_code_cache_version(self):
# raise this whenever modifying the code below.
return (6,)
return (7,)
def c_code_helper(self, bottom, weights, top, direction, sub,
height=None, width=None, depth=None):
......
......@@ -487,6 +487,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject *output;
if (direction == 0) { // forward pass
output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im3d2col, then gemm
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......@@ -538,6 +549,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
}
else if (direction == 1) { // backprop wrt. weights
output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im3col, then gemm
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......@@ -589,9 +611,29 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return NULL;
}
}
if (batchSize == 0) {
err = GpuArray_memset(&weight->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
}
}
else if (direction == 2) { // backprop wrt. inputs
output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im3d
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......
......@@ -418,6 +418,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject *output;
if (direction == 0) { // forward pass
output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......@@ -469,6 +480,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
}
else if (direction == 1) { // backprop wrt. weights
output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. weights could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......@@ -523,6 +545,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
}
else if (direction == 2) { // backprop wrt. inputs
output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
err = GpuArray_memset(&output->ga, 0);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. inputs could not fill the output with zeros: %d", err);
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im
// Iterate over batch
for (size_t n = 0; n < batchSize; n++) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论