提交 688c897f 作者: Rebecca N. Palmer

Use 64-bit indices in sparse.AddSD where necessary (fix #5525)

上级 f3844589
...@@ -137,28 +137,29 @@ class AddSD_ccode(gof.op.Op): ...@@ -137,28 +137,29 @@ class AddSD_ccode(gof.op.Op):
} }
npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1;
const npy_int32 * __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s);
const npy_int32 * __restrict__ indices = (npy_int32*)PyArray_DATA(%(_indices)s); const dtype_%(_indptr)s* __restrict__ indptr = (dtype_%(_indptr)s*)PyArray_DATA(%(_indptr)s);
const dtype_%(_indices)s* __restrict__ indices = (dtype_%(_indices)s*)PyArray_DATA(%(_indices)s);
const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s);
dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s); dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s);
dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s); dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s);
int Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize; npy_intp Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize;
int Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize; npy_intp Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize;
npy_int32 pos; npy_intp pos;
if (%(format)s == 0){ if (%(format)s == 0){
for (npy_int32 col = 0; col < N; ++col){ for (npy_intp col = 0; col < N; ++col){
for (npy_int32 ind = indptr[col]; ind < indptr[col+1]; ++ind){ for (dtype_%(_indptr)s ind = indptr[col]; ind < indptr[col+1]; ++ind){
npy_int32 row = indices[ind]; npy_intp row = indices[ind];
pos = row * Yi + col * Yj; pos = row * Yi + col * Yj;
zdata[pos] = ydata[pos] + data[ind]; zdata[pos] = ydata[pos] + data[ind];
} }
} }
}else{ }else{
for (npy_int32 row = 0; row < N; ++row){ for (npy_intp row = 0; row < N; ++row){
for (npy_int32 ind = indptr[row]; ind < indptr[row+1]; ++ind){ for (dtype_%(_indptr)s ind = indptr[row]; ind < indptr[row+1]; ++ind){
npy_int32 col = indices[ind]; npy_intp col = indices[ind];
pos = row * Yi + col * Yj; pos = row * Yi + col * Yj;
zdata[pos] = ydata[pos] + data[ind]; zdata[pos] = ydata[pos] + data[ind];
} }
...@@ -171,7 +172,7 @@ class AddSD_ccode(gof.op.Op): ...@@ -171,7 +172,7 @@ class AddSD_ccode(gof.op.Op):
return [shapes[3]] return [shapes[3]]
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
@gof.local_optimizer([sparse.AddSD]) @gof.local_optimizer([sparse.AddSD])
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论