提交 85f08efd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Always make a copy of a non-compliant array.

PyArray_GETCONTIGUOUS considers 0-sized arrays as contiguous even if they have, for instance, negative strides on the 0-length dimension, which confuses BLAS.
上级 2ac4b4ca
...@@ -496,7 +496,7 @@ class GemmRelated(Op): ...@@ -496,7 +496,7 @@ class GemmRelated(Op):
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| ((Sx[0] != type_size) && (Sx[1] != type_size))) || ((Sx[0] != type_size) && (Sx[1] != type_size)))
{ {
PyArrayObject * _x_copy = PyArray_GETCONTIGUOUS(%(_x)s); PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s);
if (!_x_copy) if (!_x_copy)
%(fail)s %(fail)s
Py_XDECREF(%(_x)s); Py_XDECREF(%(_x)s);
...@@ -507,7 +507,7 @@ class GemmRelated(Op): ...@@ -507,7 +507,7 @@ class GemmRelated(Op):
if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| ((Sy[0] != type_size) && (Sy[1] != type_size))) || ((Sy[0] != type_size) && (Sy[1] != type_size)))
{ {
PyArrayObject * _y_copy = PyArray_GETCONTIGUOUS(%(_y)s); PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s);
if (!_y_copy) if (!_y_copy)
%(fail)s %(fail)s
Py_XDECREF(%(_y)s); Py_XDECREF(%(_y)s);
...@@ -518,7 +518,7 @@ class GemmRelated(Op): ...@@ -518,7 +518,7 @@ class GemmRelated(Op):
if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size) if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)
|| ((Sz[0] != type_size) && (Sz[1] != type_size))) || ((Sz[0] != type_size) && (Sz[1] != type_size)))
{ {
PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_zout)s); PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s);
if (!_z_copy) if (!_z_copy)
%(fail)s %(fail)s
Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
...@@ -655,7 +655,7 @@ class GemmRelated(Op): ...@@ -655,7 +655,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (11,) return (12,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
......
...@@ -371,7 +371,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -371,7 +371,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
dims[0] = Nx0; dims[0] = Nx0;
dims[1] = Nx1; dims[1] = Nx1;
PyArrayObject * xx_copy = PyArray_GETCONTIGUOUS(%(xx)s); PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(%(xx)s);
if (!xx_copy) if (!xx_copy)
%(fail)s %(fail)s
Py_XDECREF(%(xx)s); Py_XDECREF(%(xx)s);
...@@ -475,7 +475,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -475,7 +475,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论