提交 84dd6ae8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix copy of input when output is not c-contiguous.

Also make sure the output memory layout is suitable, otherwise it would be copied a second time.
上级 3a4e6c78
...@@ -845,9 +845,15 @@ class Gemm(GemmRelated): ...@@ -845,9 +845,15 @@ class Gemm(GemmRelated):
setup_z_Nz_Sz_outplace = """ setup_z_Nz_Sz_outplace = """
if ((NULL == %(_zout)s) if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0]) || (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])) || (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1])
|| (%(_zout)s->strides[0] <= 0)
|| (%(_zout)s->strides[1] <= 0)
|| (%(_zout)s->strides[0] MOD type_size)
|| (%(_zout)s->strides[1] MOD type_size)
|| ((%(_zout)s->strides[0] != type_size)
&& (%(_zout)s->strides[1] != type_size)))
{ {
if (%(_zout)s) Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0]; dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1]; dims[1] = %(_z)s->dimensions[1];
...@@ -862,11 +868,12 @@ class Gemm(GemmRelated): ...@@ -862,11 +868,12 @@ class Gemm(GemmRelated):
} }
Nz = %(_zout)s->dimensions; Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides; Sz = %(_zout)s->strides;
if (1) // COPY z -> zout
{
if (%(_zout)s->descr->type_num == PyArray_FLOAT) if (%(_zout)s->descr->type_num == PyArray_FLOAT)
{ {
float * zoutdata = (float*)%(_zout)s->data; float * zoutdata = (float*)%(_zout)s->data;
int zoi = Sz[0] / sizeof(float);
int zoj = Sz[1] / sizeof(float);
const float * zdata = (float*)%(_z)s->data; const float * zdata = (float*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(float); int zi = %(_z)s->strides[0]/sizeof(float);
int zj = %(_z)s->strides[1]/sizeof(float); int zj = %(_z)s->strides[1]/sizeof(float);
...@@ -874,13 +881,15 @@ class Gemm(GemmRelated): ...@@ -874,13 +881,15 @@ class Gemm(GemmRelated):
{ {
for (int j = 0; j < Nz[1]; ++j) for (int j = 0; j < Nz[1]; ++j)
{ {
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j]; zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
} }
} }
} }
else if (%(_zout)s->descr->type_num == PyArray_DOUBLE) else if (%(_zout)s->descr->type_num == PyArray_DOUBLE)
{ {
double * zoutdata = (double*) %(_zout)s->data; double * zoutdata = (double*) %(_zout)s->data;
int zoi = Sz[0] / sizeof(double);
int zoj = Sz[1] / sizeof(double);
const double * zdata = (double*)%(_z)s->data; const double * zdata = (double*)%(_z)s->data;
int zi = %(_z)s->strides[0]/sizeof(double); int zi = %(_z)s->strides[0]/sizeof(double);
int zj = %(_z)s->strides[1]/sizeof(double); int zj = %(_z)s->strides[1]/sizeof(double);
...@@ -888,7 +897,7 @@ class Gemm(GemmRelated): ...@@ -888,7 +897,7 @@ class Gemm(GemmRelated):
{ {
for (int j = 0; j < Nz[1]; ++j) for (int j = 0; j < Nz[1]; ++j)
{ {
zoutdata[i*Nz[1]+j] = zdata[zi*i+zj*j]; zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
} }
} }
} }
...@@ -898,7 +907,6 @@ class Gemm(GemmRelated): ...@@ -898,7 +907,6 @@ class Gemm(GemmRelated):
"neither float nor double dtype"); "neither float nor double dtype");
%(fail)s %(fail)s
} }
}
""" """
case_float_ab_constants = """ case_float_ab_constants = """
...@@ -938,7 +946,7 @@ class Gemm(GemmRelated): ...@@ -938,7 +946,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (3,) + gv return (4,) + gv
else: else:
return gv return gv
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论