提交 e3063817 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fail if array copy failed.

上级 1143c9f8
...@@ -497,6 +497,8 @@ class GemmRelated(Op): ...@@ -497,6 +497,8 @@ class GemmRelated(Op):
|| ((Sx[0] != type_size) && (Sx[1] != type_size))) || ((Sx[0] != type_size) && (Sx[1] != type_size)))
{ {
PyArrayObject * _x_copy = PyArray_GETCONTIGUOUS(%(_x)s); PyArrayObject * _x_copy = PyArray_GETCONTIGUOUS(%(_x)s);
if (!_x_copy)
%(fail)s
Py_XDECREF(%(_x)s); Py_XDECREF(%(_x)s);
%(_x)s = _x_copy; %(_x)s = _x_copy;
Sx = %(_x)s->strides; Sx = %(_x)s->strides;
...@@ -506,6 +508,8 @@ class GemmRelated(Op): ...@@ -506,6 +508,8 @@ class GemmRelated(Op):
|| ((Sy[0] != type_size) && (Sy[1] != type_size))) || ((Sy[0] != type_size) && (Sy[1] != type_size)))
{ {
PyArrayObject * _y_copy = PyArray_GETCONTIGUOUS(%(_y)s); PyArrayObject * _y_copy = PyArray_GETCONTIGUOUS(%(_y)s);
if (!_y_copy)
%(fail)s
Py_XDECREF(%(_y)s); Py_XDECREF(%(_y)s);
%(_y)s = _y_copy; %(_y)s = _y_copy;
Sy = %(_y)s->strides; Sy = %(_y)s->strides;
...@@ -515,6 +519,8 @@ class GemmRelated(Op): ...@@ -515,6 +519,8 @@ class GemmRelated(Op):
|| ((Sz[0] != type_size) && (Sz[1] != type_size))) || ((Sz[0] != type_size) && (Sz[1] != type_size)))
{ {
PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_zout)s); PyArrayObject * _z_copy = PyArray_GETCONTIGUOUS(%(_zout)s);
if (!_z_copy)
%(fail)s
Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy; %(_zout)s = _z_copy;
Sz = %(_zout)s->strides; Sz = %(_zout)s->strides;
...@@ -649,7 +655,7 @@ class GemmRelated(Op): ...@@ -649,7 +655,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (10,) return (11,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论