提交 65798698 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Fix CorrMM for zero-sized inputs (avoid sgemm errors).

上级 0a361bbc
...@@ -123,7 +123,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -123,7 +123,7 @@ class BaseCorrMM(gof.OpenMPOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (4, self.openmp, blas_header_version()) return (5, self.openmp, blas_header_version())
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
......
...@@ -123,7 +123,7 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -123,7 +123,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (4, self.openmp, blas_header_version()) return (5, self.openmp, blas_header_version())
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
......
...@@ -253,7 +253,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -253,7 +253,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
char Trans = 'T'; char Trans = 'T';
PyArrayObject *output; PyArrayObject *output;
if (direction == 0) { // forward pass if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
switch(direction) {
case 0:
output = top;
break;
case 1:
output = weight;
break;
case 2:
output = bottom;
break;
default:
return NULL;
}
PyArray_FILLWBYTE(output, 0);
}
else if (direction == 0) { // forward pass
output = top; output = top;
// valid correlation: im3d2col, then gemm // valid correlation: im3d2col, then gemm
// Iterate over batch // Iterate over batch
......
...@@ -226,7 +226,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -226,7 +226,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
char Trans = 'T'; char Trans = 'T';
PyArrayObject *output; PyArrayObject *output;
if (direction == 0) { // forward pass if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
switch(direction) {
case 0:
output = top;
break;
case 1:
output = weight;
break;
case 2:
output = bottom;
break;
default:
return NULL;
}
PyArray_FILLWBYTE(output, 0);
}
else if (direction == 0) { // forward pass
output = top; output = top;
// valid correlation: im2col, then gemm // valid correlation: im2col, then gemm
// Iterate over batch // Iterate over batch
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论