提交 75a1ffc5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix crash in code for non-contiguous matrices in gemm-related op.

上级 ff8d1e25
......@@ -311,17 +311,10 @@ class GemmRelated(Op):
if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)
|| ((Sz[0] != type_size) && (Sz[1] != type_size)))
{
PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_z)s);
Py_XDECREF(%(_z)s);
// if we work inplace, %(_zout) should also point to _z_copy
if (%(_z)s == %(_zout)s)
{
Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy;
Py_XINCREF(%(_zout)s);
}
%(_z)s = _z_copy;
Sz = %(_z)s->strides;
PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_zout)s);
Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy;
Sz = %(_zout)s->strides;
}
"""
......@@ -443,7 +436,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
return (5,)
return (6,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论