提交 589f0869 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix op for non-contiguous input and output

上级 0299784f
......@@ -2411,8 +2411,6 @@ class UsmmCscDense(gof.Op):
npy_intp K = %(y)s->dimensions[0];
// pointers to access actual data in the arrays passed as params.
dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)%(z)s->data;
dtype_%(zn)s* __restrict__ Dzn = (dtype_%(zn)s*)%(zn)s->data;
const dtype_%(x_val)s* __restrict__ Dval = (dtype_%(x_val)s*)%(x_val)s->data;
const npy_int32 * __restrict__ Dind = (npy_int32*)%(x_ind)s->data;
const npy_int32 * __restrict__ Dptr = (npy_int32*)%(x_ptr)s->data;
......@@ -2428,7 +2426,11 @@ class UsmmCscDense(gof.Op):
if (!(%(inplace)s))
{
memcpy(Dzn, Dz, M*N*sizeof(dtype_%(zn)s));
if (PyArray_CopyInto(%(zn)s, %(z)s))
{
Py_XDECREF(%(zn)s);
%(fail)s;
}
}
for (npy_int32 k = 0; k < K; ++k)
......@@ -2439,9 +2441,16 @@ class UsmmCscDense(gof.Op):
const dtype_%(x_val)s Amk = alpha * Dval[m_idx * Sval]; // actual value at that location
const dtype_%(y)s* y_row = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * k);
dtype_%(y)s* y_row = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * k);
// axpy expects pointer to the beginning of memory arrays,
// so when the stride is negative, we need to get the
// last element
if (Sy < 0)
y_row += (K - 1) * Sy;
const dtype_%(zn)s* z_row = (dtype_%(zn)s*)(%(zn)s->data + %(zn)s->strides[0] * m);
dtype_%(zn)s* z_row = (dtype_%(zn)s*)(%(zn)s->data + %(zn)s->strides[0] * m);
if (Szn < 0)
z_row += (N - 1) * Szn;
%(axpy)s((int*)&N, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, (int*)&Sy, (%(conv_type)s*)z_row, (int*)&Szn);
}
......@@ -2451,6 +2460,9 @@ class UsmmCscDense(gof.Op):
return rval
def c_code_cache_version(self):
return (1,)
usmm_csc_dense = UsmmCscDense(inplace=False)
usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论