提交 688c897f，作者：Rebecca N. Palmer

Use 64-bit indices in sparse.AddSD where necessary (fix #5525)

上级 f3844589
......@@ -137,28 +137,29 @@ class AddSD_ccode(gof.op.Op):
 }
 npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1;
-const npy_int32 * __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s);
-const npy_int32 * __restrict__ indices = (npy_int32*)PyArray_DATA(%(_indices)s);
+const dtype_%(_indptr)s* __restrict__ indptr = (dtype_%(_indptr)s*)PyArray_DATA(%(_indptr)s);
+const dtype_%(_indices)s* __restrict__ indices = (dtype_%(_indices)s*)PyArray_DATA(%(_indices)s);
 const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s);
 dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s);
 dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s);
-int Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize;
-int Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize;
+npy_intp Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize;
+npy_intp Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize;
-npy_int32 pos;
+npy_intp pos;
 if (%(format)s == 0){
-for (npy_int32 col = 0; col < N; ++col){
-for (npy_int32 ind = indptr[col]; ind < indptr[col+1]; ++ind){
-npy_int32 row = indices[ind];
+for (npy_intp col = 0; col < N; ++col){
+for (dtype_%(_indptr)s ind = indptr[col]; ind < indptr[col+1]; ++ind){
+npy_intp row = indices[ind];
 pos = row * Yi + col * Yj;
 zdata[pos] = ydata[pos] + data[ind];
 }
 }
 }else{
-for (npy_int32 row = 0; row < N; ++row){
-for (npy_int32 ind = indptr[row]; ind < indptr[row+1]; ++ind){
-npy_int32 col = indices[ind];
+for (npy_intp row = 0; row < N; ++row){
+for (dtype_%(_indptr)s ind = indptr[row]; ind < indptr[row+1]; ++ind){
+npy_intp col = indices[ind];
 pos = row * Yi + col * Yj;
 zdata[pos] = ydata[pos] + data[ind];
 }
......@@ -171,7 +172,7 @@ class AddSD_ccode(gof.op.Op):
 return [shapes[3]]
 def c_code_cache_version(self):
-return (1,)
+return (2,)
 @gof.local_optimizer([sparse.AddSD])
......
Markdown 格式
0%
您即将向此讨论添加 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论