提交 84dd6ae8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix the copy of the input into the output when the output is not C-contiguous.

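Also make sure the output's memory layout is suitable (positive strides that are multiples of the item size, with a unit stride along at least one dimension); otherwise the output would be copied a second time.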
上级 3a4e6c78
......@@ -845,9 +845,15 @@ class Gemm(GemmRelated):
setup_z_Nz_Sz_outplace = """
if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1]))
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])
|| (%(_zout)s->strides[0] <= 0)
|| (%(_zout)s->strides[1] <= 0)
|| (%(_zout)s->strides[0] MOD type_size)
|| (%(_zout)s->strides[1] MOD type_size)
|| ((%(_zout)s->strides[0] != type_size)
&& (%(_zout)s->strides[1] != type_size)))
{
if (%(_zout)s) Py_XDECREF(%(_zout)s);
Py_XDECREF(%(_zout)s);
npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1];
......@@ -862,42 +868,44 @@ class Gemm(GemmRelated):
}
Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides;
if (1) // COPY z -> zout
if (%(_zout)s->descr->type_num == PyArray_FLOAT)
{
if (%(_zout)s->descr->type_num == PyArray_FLOAT)
float * zoutdata = (float*)%(_zout)s->data;
int zoi = Sz[0] / sizeof(float);
int zoj = Sz[1] / sizeof(float);
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{
float * zoutdata = (float*)%(_zout)s->data;
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
for (int j = 0; j < Nz[1]; ++j)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
}
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
}
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(_zout)s->data;
int zoi = Sz[0] / sizeof(double);
int zoj = Sz[1] / sizeof(double);
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{
double * zoutdata = (double*) %(_zout)s->data;
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
for (int j = 0; j < Nz[1]; ++j)
{
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
"""
......@@ -938,7 +946,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
return (3,) + gv
return (4,) + gv
else:
return gv
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论