提交 c166a9e7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2352 from lamblin/fix_conv2d_stride_check

Use numpy's function to check array is contiguous
......@@ -964,7 +964,7 @@ class ConvOp(OpenMPOp):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>']
def c_code_cache_version(self):
return (11, self.openmp, blas.blas_header_version())
return (12, self.openmp, blas.blas_header_version())
def c_support_code(self):
return """
......@@ -2157,11 +2157,12 @@ if ((!%(z)s)
}
z_arr = (PyArrayObject*) %(z)s;
//assertions
if (PyArray_STRIDES(%(z)s)[0] != PyArray_DIMS(%(z)s)[1] *PyArray_DIMS(%(z)s)[2] *PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s;
if (PyArray_STRIDES(%(z)s)[1] != PyArray_DIMS(%(z)s)[2] * PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s;
if (PyArray_STRIDES(%(z)s)[2] != PyArray_DIMS(%(z)s)[3] * sizeof(%(type)s)) %(fail)s;
if (PyArray_STRIDES(%(z)s)[3] != sizeof(%(type)s)) %(fail)s;
// assert the output is C-contiguous
if (!PyArray_ISCONTIGUOUS(%(z)s))
{
PyErr_SetString(PyExc_AssertionError, "Output (%(z)s) not contiguous");
%(fail)s;
}
//The if on the number of loop make a speed up for small array.
//with g++ 4.5.1. The compiler should be smart enough to do this himself!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论