Commit 84dd6ae8, authored by Pascal Lamblin

Fix copy of input when output is not C-contiguous.

Also make sure the output memory layout is suitable; otherwise, it would be copied a second time.
上级 3a4e6c78
...@@ -845,9 +845,15 @@ class Gemm(GemmRelated): ...@@ -845,9 +845,15 @@ class Gemm(GemmRelated):
setup_z_Nz_Sz_outplace = """ setup_z_Nz_Sz_outplace = """
if ((NULL == %(_zout)s) if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0]) || (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])) || (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])
|| (%(_zout)s->strides[0] <= 0)
|| (%(_zout)s->strides[1] <= 0)
|| (%(_zout)s->strides[0] MOD type_size)
|| (%(_zout)s->strides[1] MOD type_size)
|| ((%(_zout)s->strides[0] != type_size)
&& (%(_zout)s->strides[1] != type_size)))
{ {
if (%(_zout)s) Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0]; dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1]; dims[1] = %(_z)s->dimensions[1];
...@@ -862,42 +868,44 @@ class Gemm(GemmRelated): ...@@ -862,42 +868,44 @@ class Gemm(GemmRelated):
} }
Nz = %(_zout)s->dimensions; Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides; Sz = %(_zout)s->strides;
if (1) // COPY z -> zout
if (%(_zout)s->descr->type_num == PyArray_FLOAT)
{ {
if (%(_zout)s->descr->type_num == PyArray_FLOAT) float * zoutdata = (float*)%(_zout)s->data;
int zoi = Sz[0] / sizeof(float);
int zoj = Sz[1] / sizeof(float);
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{ {
float * zoutdata = (float*)%(_zout)s->data; for (int j = 0; j < Nz[1]; ++j)
const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{ {
for (int j = 0; j < Nz[1]; ++j) zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
} }
} }
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE) }
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(_zout)s->data;
int zoi = Sz[0] / sizeof(double);
int zoj = Sz[1] / sizeof(double);
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{ {
double * zoutdata = (double*) %(_zout)s->data; for (int j = 0; j < Nz[1]; ++j)
const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{ {
for (int j = 0; j < Nz[1]; ++j) zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
{
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j];
}
} }
} }
else }
{ else
PyErr_SetString(PyExc_AssertionError, {
"neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
%(fail)s "neither float nor double dtype");
} %(fail)s
} }
""" """
...@@ -938,7 +946,7 @@ class Gemm(GemmRelated): ...@@ -938,7 +946,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (3,) + gv return (4,) + gv
else: else:
return gv return gv
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论