Commit 77863a7b, authored by Pascal Lamblin

Fix bug in GemmRelated C-code when matrices are non-contiguous

Parent commit: 6cd840c7
...@@ -261,7 +261,7 @@ class GemmRelated(Op): ...@@ -261,7 +261,7 @@ class GemmRelated(Op):
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
""" """
check_dims_strides = """ check_dims = """
if (Nx[0] != Nz[0]) if (Nx[0] != Nz[0])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
...@@ -283,11 +283,45 @@ class GemmRelated(Op): ...@@ -283,11 +283,45 @@ class GemmRelated(Op):
(long int)Ny[1], (long int)Nz[1]); (long int)Ny[1], (long int)Nz[1]);
%(fail)s; %(fail)s;
} }
"""
check_strides = """
/*
If some matrices are not contiguous on either dimensions,
or have invalid strides, copy their content into a contiguous one
*/
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) || ((Sx[0] != type_size) && (Sx[1] != type_size)))
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)) {
PyArrayObject * _x_copy = PyArray_GETCONTIGUOUS(%(_x)s);
Py_XDECREF(%(_x)s);
%(_x)s = _x_copy;
Sx = %(_x)s->strides;
}
if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| ((Sy[0] != type_size) && (Sy[1] != type_size)))
{
PyArrayObject * _y_copy = PyArray_GETCONTIGUOUS(%(_y)s);
Py_XDECREF(%(_y)s);
%(_y)s = _y_copy;
Sy = %(_y)s->strides;
}
if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)
|| ((Sz[0] != type_size) && (Sz[1] != type_size)))
{ {
PyErr_SetString(PyExc_NotImplementedError, "stride is not multiple of element size"); %(fail)s; PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_z)s);
Py_XDECREF(%(_z)s);
// if we work inplace, %(_zout) should also point to _z_copy
if (%(_z)s == %(_zout)s)
{
Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy;
Py_XINCREF(%(_zout)s);
}
%(_z)s = _z_copy;
Sz = %(_z)s->strides;
} }
""" """
...@@ -395,7 +429,8 @@ class GemmRelated(Op): ...@@ -395,7 +429,8 @@ class GemmRelated(Op):
self.check_xyz_rank2, self.check_xyz_rank2,
self.check_xyz_double_or_float, self.check_xyz_double_or_float,
self.check_ab_double_or_float, self.check_ab_double_or_float,
self.check_dims_strides, self.check_dims,
self.check_strides,
self.encode_strides_in_unit, self.encode_strides_in_unit,
self.compute_strides, self.compute_strides,
self.begin_switch_typenum, self.begin_switch_typenum,
...@@ -408,6 +443,7 @@ class GemmRelated(Op): ...@@ -408,6 +443,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return ()
return (4,) return (4,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论