提交 a5256ba2 作者: James Bergstra

Important bug fix for GEMM: fixes a bug present since the in-place GEMM modification made a few days ago.

上级 7c84cbe2
...@@ -102,11 +102,11 @@ class GemmRelated(Op): ...@@ -102,11 +102,11 @@ class GemmRelated(Op):
npy_intp* Nx = %(_x)s->dimensions; npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions; npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions; npy_intp* Nz = 0; //%(_zout)s->dimensions;
npy_intp* Sx = %(_x)s->strides; npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides; npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0; //%(_z)s->strides; npy_intp* Sz = 0; //%(_zout)s->strides;
//strides for x, y, z in dimensions 0, 1 //strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
...@@ -117,7 +117,7 @@ class GemmRelated(Op): ...@@ -117,7 +117,7 @@ class GemmRelated(Op):
check_xyz_rank2 = """ check_xyz_rank2 = """
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;} if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
""" """
check_xyz_double_or_float = """ check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE) if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
...@@ -128,12 +128,12 @@ class GemmRelated(Op): ...@@ -128,12 +128,12 @@ class GemmRelated(Op):
&& (%(_y)s->descr->type_num != PyArray_FLOAT)) && (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE) if ((%(_zout)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT)) && (%(_zout)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num) if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num)) ||(%(_x)s->descr->type_num != %(_zout)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; } { PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
""" """
...@@ -164,7 +164,7 @@ class GemmRelated(Op): ...@@ -164,7 +164,7 @@ class GemmRelated(Op):
encode_strides_in_unit = """ encode_strides_in_unit = """
/* /*
encode the stride structure of _x,_y,_z into a single integer encode the stride structure of _x,_y,_zout into a single integer
*/ */
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8; unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
...@@ -198,7 +198,7 @@ class GemmRelated(Op): ...@@ -198,7 +198,7 @@ class GemmRelated(Op):
case_float_gemm = """ case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s); float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s); float* z = (float*)PyArray_DATA(%(_zout)s);
char N = 'N'; char N = 'N';
char T = 'T'; char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
...@@ -229,7 +229,7 @@ class GemmRelated(Op): ...@@ -229,7 +229,7 @@ class GemmRelated(Op):
case_double_gemm = """ case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s); double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s); double* z = (double*)PyArray_DATA(%(_zout)s);
char N = 'N'; char N = 'N';
char T = 'T'; char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
...@@ -393,7 +393,7 @@ class Gemm(GemmRelated): ...@@ -393,7 +393,7 @@ class Gemm(GemmRelated):
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0]) || (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])) || (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1]))
{ {
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s); if (%(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0]; dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1]; dims[1] = %(_z)s->dimensions[1];
...@@ -405,6 +405,42 @@ class Gemm(GemmRelated): ...@@ -405,6 +405,42 @@ class Gemm(GemmRelated):
} }
Nz = %(_zout)s->dimensions; Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides; Sz = %(_zout)s->strides;
if (1) // COPY z -> zout
{
if (%(_zout)s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)%(_zout)s->data;
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
}
}
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(_zout)s->data;
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
""" """
case_float_ab_constants = """ case_float_ab_constants = """
...@@ -435,7 +471,7 @@ class Gemm(GemmRelated): ...@@ -435,7 +471,7 @@ class Gemm(GemmRelated):
return full_code return full_code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) + self.build_gemm_version() return (3,) + self.build_gemm_version()
gemm_inplace = Gemm(inplace=True) gemm_inplace = Gemm(inplace=True)
gemm_no_inplace = Gemm(inplace=False) gemm_no_inplace = Gemm(inplace=False)
...@@ -684,22 +720,22 @@ class Dot22(GemmRelated): ...@@ -684,22 +720,22 @@ class Dot22(GemmRelated):
return "_dot22" return "_dot22"
setup_z_Nz_Sz = """ setup_z_Nz_Sz = """
if ((NULL == %(_z)s) if ((NULL == %(_zout)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0]) || (%(_zout)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1])) || (%(_zout)s->dimensions[1] != %(_y)s->dimensions[1]))
{ {
if (NULL != %(_z)s) Py_XDECREF(%(_z)s); if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0]; dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1]; dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s); %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output"); PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s %(fail)s
} }
} }
Nz = %(_z)s->dimensions; Nz = %(_zout)s->dimensions;
Sz = %(_z)s->strides; Sz = %(_zout)s->strides;
""" """
check_ab_double_or_float = "" check_ab_double_or_float = ""
...@@ -711,9 +747,9 @@ class Dot22(GemmRelated): ...@@ -711,9 +747,9 @@ class Dot22(GemmRelated):
double a = 1.0; double a = 1.0;
double b = 0.0; double b = 0.0;
""" """
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG def c_code(self, node, name, (_x, _y), (_zout, ), sub): #DEBUG
if len(self.c_libraries())<=0: if len(self.c_libraries())<=0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_z, ), sub) return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
def c_code_cache_version(self): def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论