提交 0a2a930f authored 作者: Rami Al-Rfou's avatar Rami Al-Rfou

C code now supports the CSR format and non-in-place operations

上级 a961de79
......@@ -1744,29 +1744,46 @@ class AddSD(gof.op.Op):
broadcastable=y.type.broadcastable
).make_variable()])
def c_code(self, node, name, inputs, outputs, sub):
    """Return the C source that adds a sparse matrix to a dense one.

    The sparse operand is given by its three component arrays and is
    scattered onto the dense operand ``y``; the result is stored in ``z``.

    :param node: unused by this method.
    :param name: unused by this method.
    :param inputs: 4-sequence ``(_data, _indices, _indptr, y)`` of C
        variable names: the values, minor indices and index pointers of
        the sparse matrix, followed by the dense matrix.
    :param outputs: 1-sequence ``(z,)`` with the C name of the output.
    :param sub: extra template substitutions; its entries override the
        ones built here. ``sub['fail']`` must be present: it is emitted
        when the copy of ``y`` cannot be allocated.
    :returns: the C code as a string.
    :raises KeyError: if ``self.format`` is neither ``'csc'`` nor
        ``'csr'``, or if ``sub`` has no ``'fail'`` entry.
    """
    # Tuple parameters in the signature are Python 2 only (PEP 3113);
    # unpacking here keeps the positional calling convention unchanged.
    (_data, _indices, _indptr, y) = inputs
    (z,) = outputs

    template_vars = dict(
        _data=_data,
        _indices=_indices,
        _indptr=_indptr,
        y=y,
        z=z,
        inplace=int(self.inplace),
        # 0 selects the column-major (csc) loop, 1 the row-major (csr) one.
        format={'csc': 0, 'csr': 1}[self.format],
    )
    template_vars.update(sub)

    # NOTE(review): the generated C changed; if the class defines
    # c_code_cache_version, it presumably needs a bump -- confirm.
    # NOTE(review): the copy made by PyArray_NewCopy has y's dtype while
    # the buffer is accessed as dtype_z; this assumes both are the same
    # type -- confirm against make_node.
    code = """
        Py_XDECREF(%(z)s);
        if (!%(inplace)s){
            // PyArray_NewCopy already returns a new reference owned by
            // the output, so no extra INCREF on this path.
            %(z)s = (PyArrayObject *) PyArray_NewCopy(%(y)s, NPY_CORDER);
            if (!%(z)s) {
                %(fail)s
            }
        }else{
            %(z)s = %(y)s;
            Py_XINCREF(%(z)s);
        }
        {
            // Nested scope: the failure jump above must not cross these
            // initializations.
            const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1;

            const npy_int32 * __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s);
            const npy_int32 * __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s);
            const dtype_%(_data)s * __restrict__ data = (dtype_%(_data)s *)PyArray_DATA(%(_data)s);

            // z is either y itself or a copy of y, so it already holds
            // the dense values; index it with its own strides because a
            // C-order copy need not share y's memory layout.
            dtype_%(z)s * zdata = (dtype_%(z)s *)PyArray_DATA(%(z)s);
            const npy_intp Zi = PyArray_STRIDES(%(z)s)[0]/PyArray_DESCR(%(z)s)->elsize;
            const npy_intp Zj = PyArray_STRIDES(%(z)s)[1]/PyArray_DESCR(%(z)s)->elsize;

            npy_intp pos;
            if (%(format)s == 0){
                for (npy_intp col = 0; col < N; ++col){
                    for (npy_int32 ind = indptr[col]; ind < indptr[col+1]; ++ind){
                        const npy_intp row = indices[ind];
                        pos = row * Zi + col * Zj;
                        zdata[pos] = zdata[pos] + data[ind];
                    }
                }
            }else{
                for (npy_intp row = 0; row < N; ++row){
                    for (npy_int32 ind = indptr[row]; ind < indptr[row+1]; ++ind){
                        const npy_intp col = indices[ind];
                        pos = row * Zi + col * Zj;
                        zdata[pos] = zdata[pos] + data[ind];
                    }
                }
            }
        }
        """ % template_vars
    return code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论