提交 a5256ba2 authored 作者: James Bergstra's avatar James Bergstra

Important bug-fix to GEMM since the inplace-gemm modification a few days ago.

上级 7c84cbe2
......@@ -102,11 +102,11 @@ class GemmRelated(Op):
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
npy_intp* Nz = 0; //%(_zout)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0; //%(_z)s->strides;
npy_intp* Sz = 0; //%(_zout)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
......@@ -117,7 +117,7 @@ class GemmRelated(Op):
check_xyz_rank2 = """
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
......@@ -128,12 +128,12 @@ class GemmRelated(Op):
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
if ((%(_zout)s->descr->type_num != PyArray_DOUBLE)
&& (%(_zout)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
||(%(_x)s->descr->type_num != %(_zout)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
"""
......@@ -164,7 +164,7 @@ class GemmRelated(Op):
encode_strides_in_unit = """
/*
encode the stride structure of _x,_y,_z into a single integer
encode the stride structure of _x,_y,_zout into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
......@@ -198,7 +198,7 @@ class GemmRelated(Op):
case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
float* z = (float*)PyArray_DATA(%(_zout)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
......@@ -229,7 +229,7 @@ class GemmRelated(Op):
case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
double* z = (double*)PyArray_DATA(%(_zout)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
......@@ -393,7 +393,7 @@ class Gemm(GemmRelated):
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1]))
{
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
if (%(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1];
......@@ -405,6 +405,42 @@ class Gemm(GemmRelated):
}
Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides;
if (1) // COPY z -> zout
{
if (%(_zout)s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)%(_zout)s->data;
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
}
}
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(_zout)s->data;
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
"""
case_float_ab_constants = """
......@@ -435,7 +471,7 @@ class Gemm(GemmRelated):
return full_code
def c_code_cache_version(self):
return (2,) + self.build_gemm_version()
return (3,) + self.build_gemm_version()
gemm_inplace = Gemm(inplace=True)
gemm_no_inplace = Gemm(inplace=False)
......@@ -684,22 +720,22 @@ class Dot22(GemmRelated):
return "_dot22"
setup_z_Nz_Sz = """
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides;
"""
check_ab_double_or_float = ""
......@@ -711,9 +747,9 @@ class Dot22(GemmRelated):
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
def c_code(self, node, name, (_x, _y), (_zout, ), sub): #DEBUG
if len(self.c_libraries())<=0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_z, ), sub)
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论