Commit f141420e authored by Frédéric Bastien, committed by GitHub

Merge pull request #5577 from rebecca-palmer/sparseadd_64bit_indexing

Use 64-bit indices in sparse.AddSD where necessary (fix #5525)
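For context (not part of the patch), here is a minimal standalone sketch of the failure mode this change guards against: once a dense operand has more than 2^31 - 1 elements, a flattened position such as `pos = row * Yi + col * Yj` no longer fits in a 32-bit index, while `npy_intp`-sized (64-bit on 64-bit platforms) arithmetic keeps it exact. The matrix dimensions below are made up for illustration.

```c
#include <stdio.h>
#include <stdint.h>

int main(void)
{
    /* Illustrative sizes: a 50000 x 50000 dense matrix has 2.5e9 elements,
       more than INT32_MAX (2147483647). */
    int64_t n_cols = 50000;
    int64_t row = 49999, col = 49999;

    /* The true offset, computed in 64-bit, is 2499999999. */
    int64_t pos64 = row * n_cols + col;

    /* Narrowing it to 32 bits cannot represent that value; on typical
       platforms the result wraps to a negative (wrong) index. */
    int32_t pos32 = (int32_t)pos64;

    printf("64-bit offset: %lld\n", (long long)pos64);
    printf("32-bit offset: %d\n", (int)pos32);
    return 0;
}
```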
@@ -137,28 +137,29 @@ class AddSD_ccode(gof.op.Op):
 }
 npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1;
-const npy_int32 * __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s);
-const npy_int32 * __restrict__ indices = (npy_int32*)PyArray_DATA(%(_indices)s);
+const dtype_%(_indptr)s* __restrict__ indptr = (dtype_%(_indptr)s*)PyArray_DATA(%(_indptr)s);
+const dtype_%(_indices)s* __restrict__ indices = (dtype_%(_indices)s*)PyArray_DATA(%(_indices)s);
 const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s);
 dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s);
 dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s);
-int Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize;
-int Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize;
-npy_int32 pos;
+npy_intp Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize;
+npy_intp Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize;
+npy_intp pos;
 if (%(format)s == 0){
-    for (npy_int32 col = 0; col < N; ++col){
-        for (npy_int32 ind = indptr[col]; ind < indptr[col+1]; ++ind){
-            npy_int32 row = indices[ind];
+    for (npy_intp col = 0; col < N; ++col){
+        for (dtype_%(_indptr)s ind = indptr[col]; ind < indptr[col+1]; ++ind){
+            npy_intp row = indices[ind];
             pos = row * Yi + col * Yj;
             zdata[pos] = ydata[pos] + data[ind];
         }
     }
 }else{
-    for (npy_int32 row = 0; row < N; ++row){
-        for (npy_int32 ind = indptr[row]; ind < indptr[row+1]; ++ind){
-            npy_int32 col = indices[ind];
+    for (npy_intp row = 0; row < N; ++row){
+        for (dtype_%(_indptr)s ind = indptr[row]; ind < indptr[row+1]; ++ind){
+            npy_intp col = indices[ind];
             pos = row * Yi + col * Yj;
             zdata[pos] = ydata[pos] + data[ind];
         }
@@ -171,7 +172,7 @@ class AddSD_ccode(gof.op.Op):
         return [shapes[3]]
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 @gof.local_optimizer([sparse.AddSD])
@@ -336,6 +337,8 @@ class StructuredDotCSC(gof.Op):
 npy_intp M = PyArray_DIMS(%(z)s)[0];
 npy_intp N = PyArray_DIMS(%(z)s)[1];
 npy_intp K = PyArray_DIMS(%(b)s)[0];
+if (N > 0x7fffffffL)
+{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;}
 // strides tell you how many bytes to skip to go to next column/row entry
 npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize;
@@ -413,7 +416,7 @@ class StructuredDotCSC(gof.Op):
         return rval
     def c_code_cache_version(self):
-        return (2,)
+        return (3,)
 sd_csc = StructuredDotCSC()
@@ -529,6 +532,8 @@ class StructuredDotCSR(gof.Op):
 npy_intp M = PyArray_DIMS(%(z)s)[0];
 npy_intp N = PyArray_DIMS(%(z)s)[1];
 npy_intp K = PyArray_DIMS(%(b)s)[0];
+if (N > 0x7fffffffL)
+{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;}
 // strides tell you how many bytes to skip to go to next column/row entry
 npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize;
@@ -590,7 +595,7 @@ class StructuredDotCSR(gof.Op):
         """ % dict(locals(), **sub)
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 sd_csr = StructuredDotCSR()
@@ -832,6 +837,8 @@ class UsmmCscDense(gof.Op):
 npy_intp Sy = PyArray_STRIDES(%(y)s)[1] / PyArray_DESCR(%(y)s)->elsize;
 // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction
+if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL))
+{PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;}
 int N32 = N;
 int Sy32 = Sy;
 int Szn32 = Szn;
@@ -845,7 +852,7 @@ class UsmmCscDense(gof.Op):
 }
 }
-for (npy_int32 k = 0; k < K; ++k)
+for (npy_intp k = 0; k < K; ++k)
 {
     for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1)*Sptr]; ++m_idx)
     {
@@ -873,7 +880,7 @@ class UsmmCscDense(gof.Op):
         return rval
     def c_code_cache_version(self):
-        return (2, blas.blas_header_version())
+        return (3, blas.blas_header_version())
 usmm_csc_dense = UsmmCscDense(inplace=False)
 usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
@@ -1106,7 +1113,7 @@ class MulSDCSC(gof.Op):
                 [tensor.tensor(b.dtype, (False,))])
     def c_code_cache_version(self):
-        return (2,)
+        return (3,)
     # def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
     #     return NotImplementedError()
@@ -1169,7 +1176,7 @@ class MulSDCSC(gof.Op):
 const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0];
 // loop over columns
-for (npy_int32 j = 0; j < N; ++j)
+for (npy_intp j = 0; j < N; ++j)
 {
     // for each non-null value in the sparse column
     for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
@@ -1232,7 +1239,7 @@ class MulSDCSR(gof.Op):
                 [tensor.tensor(b.dtype, (False,))])
     def c_code_cache_version(self):
-        return (2,)
+        return (3,)
     # def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
     #     return NotImplemented()
@@ -1295,7 +1302,7 @@ class MulSDCSR(gof.Op):
 const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0];
 // loop over columns
-for (npy_int32 j = 0; j < N; ++j)
+for (npy_intp j = 0; j < N; ++j)
 {
     // extract i-th row of dense matrix
     const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(PyArray_BYTES(%(_b)s) + Sb * j);
@@ -1400,7 +1407,7 @@ class MulSVCSR(gof.Op):
                 [tensor.tensor(b.dtype, (False,))])
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
     def c_code(self, node, name, inputs, outputs, sub):
         _data, _indices, _indptr, _b, = inputs
@@ -1459,7 +1466,7 @@ class MulSVCSR(gof.Op):
 const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize;
 // loop over rows
-for (npy_int32 j = 0; j < N; ++j)
+for (npy_intp j = 0; j < N; ++j)
 {
     // for each non-null value in the sparse column
     for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
@@ -1558,7 +1565,7 @@ class StructuredAddSVCSR(gof.Op):
                 [tensor.tensor(b.dtype, (False,))])
     def c_code_cache_version(self):
-        return (2,)
+        return (3,)
     def c_code(self, node, name, inputs, outputs, sub):
         _data, _indices, _indptr, _b, = inputs
@@ -1623,7 +1630,7 @@ class StructuredAddSVCSR(gof.Op):
 const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize;
 // loop over columns
-for (npy_int32 j = 0; j < N; ++j)
+for (npy_intp j = 0; j < N; ++j)
 {
     // for each non-null value in the sparse column
     for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
@@ -1753,7 +1760,7 @@ class SamplingDotCSR(gof.Op):
                 ])
     def c_code_cache_version(self):
-        return (3, blas.blas_header_version())
+        return (4, blas.blas_header_version())
     def c_support_code(self):
         return blas.blas_header_text()
@@ -1897,11 +1904,13 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
 memcpy(Dzp, Dpp, PyArray_DIMS(%(p_ptr)s)[0]*sizeof(dtype_%(p_ptr)s));
 // blas expects ints; convert here (rather than just making K etc ints) to avoid potential overflow in the negative-stride correction
+if ((K > 0x7fffffffL)||(Sdx > 0x7fffffffL)||(Sdy > 0x7fffffffL)||(Sdx < -0x7fffffffL)||(Sdy < -0x7fffffffL))
+{PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;}
 int K32 = K;
 int Sdx32 = Sdx;
 int Sdy32 = Sdy;
-for (npy_int32 m = 0; m < M; ++m) {
+for (npy_intp m = 0; m < M; ++m) {
     for (npy_int32 n_idx = Dpp[m * Sdpp]; n_idx < Dpp[(m+1)*Sdpp]; ++n_idx) {
         const npy_int32 n = Dpi[n_idx * Sdpi]; // row index of non-null value for column K
......
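The UsmmCscDense and SamplingDotCSR hunks above follow a guard-then-narrow pattern: sizes and strides are kept in `npy_intp` and range-checked against `0x7fffffff` before being copied into the 32-bit `int` variables that the conventional BLAS integer interface expects, raising `NotImplementedError` via `%(fail)s` when they do not fit. Below is a minimal standalone sketch of that idiom; `check_fits_int32` is a hypothetical helper used for illustration, not a Theano or BLAS function.

```c
#include <stdio.h>
#include <stdint.h>

/* Hypothetical helper: store v into a 32-bit int only if it fits
   (negative strides included); otherwise report failure so the caller
   can raise NotImplementedError, as the generated op code does. */
static int check_fits_int32(int64_t v, int32_t *out)
{
    if (v > INT32_MAX || v < -INT32_MAX)
        return -1;
    *out = (int32_t)v;
    return 0;
}

int main(void)
{
    int64_t N = 3000000000LL;  /* e.g. a dimension that would come from PyArray_DIMS */
    int32_t N32;

    if (check_fits_int32(N, &N32) != 0)
        printf("array too big for BLAS (overflows int32 index)\n");
    else
        printf("safe to pass %d to BLAS\n", (int)N32);
    return 0;
}
```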