提交 6685523c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use numpy's function to check array is contiguous

In particular, strides for dimensions of length 1 could be anything
上级 cb1e2e32
...@@ -964,7 +964,7 @@ class ConvOp(OpenMPOp): ...@@ -964,7 +964,7 @@ class ConvOp(OpenMPOp):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>'] return ['<numpy/noprefix.h>', '<iostream>', '<sstream>']
def c_code_cache_version(self): def c_code_cache_version(self):
return (11, self.openmp, blas.blas_header_version()) return (12, self.openmp, blas.blas_header_version())
def c_support_code(self): def c_support_code(self):
return """ return """
...@@ -2157,11 +2157,12 @@ if ((!%(z)s) ...@@ -2157,11 +2157,12 @@ if ((!%(z)s)
} }
z_arr = (PyArrayObject*) %(z)s; z_arr = (PyArrayObject*) %(z)s;
//assertions // assert the output is C-contiguous
if (PyArray_STRIDES(%(z)s)[0] != PyArray_DIMS(%(z)s)[1] *PyArray_DIMS(%(z)s)[2] *PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s; if (!PyArray_ISCONTIGUOUS(%(z)s))
if (PyArray_STRIDES(%(z)s)[1] != PyArray_DIMS(%(z)s)[2] * PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s; {
if (PyArray_STRIDES(%(z)s)[2] != PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s; PyErr_SetString(PyExc_AssertionError, "Output (%(z)s) not contiguous");
if (PyArray_STRIDES(%(z)s)[3] != sizeof(%(type)s)) %(fail)s; %(fail)s;
}
//The if on the number of loop make a speed up for small array. //The if on the number of loop make a speed up for small array.
//with g++ 4.5.1. The compiler should be smart enough to do this himself! //with g++ 4.5.1. The compiler should be smart enough to do this himself!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论