提交 8e633dca authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

allow for duplicates and unsorted sparse dimensions

上级 76c598ee
...@@ -698,12 +698,12 @@ class CSM(gof.Op): ...@@ -698,12 +698,12 @@ class CSM(gof.Op):
indptr.copy()), shape.copy(), indptr.copy()), shape.copy(),
copy=False) copy=False)
def grad(self, (x_data, x_indices, x_indptr, _), (g_out,)): def grad(self, (x_data, x_indices, x_indptr, x_shape), (g_out,)):
"""Return a gradient on the data vector""" """Return a gradient on the data vector"""
g_data, g_indices, g_indptr, _ = csm_properties(g_out) g_data, g_indices, g_indptr, g_shape = csm_properties(g_out)
#unpack the data vector and wrap it as a 1d TensorType #unpack the data vector and wrap it as a 1d TensorType
g_data = csm_grad(self.kmap)(x_data, x_indices, x_indptr, g_data = csm_grad(self.kmap)(x_data, x_indices, x_indptr, x_shape,
g_data, g_indices, g_indptr) g_data, g_indices, g_indptr, g_shape)
return [g_data, None, None, None] return [g_data, None, None, None]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
...@@ -735,33 +735,30 @@ class CSMGrad(gof.op.Op): ...@@ -735,33 +735,30 @@ class CSMGrad(gof.op.Op):
self.__class__.__name__, self.__class__.__name__,
self.kmap) self.kmap)
def make_node(self, x_data, x_indices, x_indptr, def make_node(self, x_data, x_indices, x_indptr, x_shape,
g_data, g_indices, g_indptr): g_data, g_indices, g_indptr, g_shape):
gout_data = g_data.type() gout_data = g_data.type()
return gof.Apply(self, [x_data, x_indices, x_indptr, return gof.Apply(self, [x_data, x_indices, x_indptr, x_shape,
g_data, g_indices, g_indptr], [gout_data]) g_data, g_indices, g_indptr, g_shape], [gout_data])
def perform(self, node, (x_data, x_indices, x_indptr, def perform(self, node, (x_data, x_indices, x_indptr, x_shape,
g_data, g_indices, g_indptr), (g_out,)): g_data, g_indices, g_indptr, g_shape), (g_out,)):
if len(x_indptr) - 1 == x_shape[0]:
sp_dim = x_shape[1]
else:
sp_dim = x_shape[0]
g_row = numpy.zeros(sp_dim, dtype=g_data.dtype)
gout_data = numpy.zeros_like(x_data) gout_data = numpy.zeros_like(x_data)
for i in range(len(x_indptr) - 1): for i in range(len(x_indptr) - 1):
x_pos = x_indptr[i] for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_pos = g_indptr[i] g_row[g_indices[j_ptr]] += g_data[j_ptr]
x_end = x_indptr[i + 1]
g_end = g_indptr[i + 1]
while x_pos < x_end and g_pos < g_end: for j_ptr in range(x_indptr[i], x_indptr[i + 1]):
x_ind = x_indices[x_pos] gout_data[j_ptr] = g_row[x_indices[j_ptr]]
g_ind = g_indices[g_pos]
for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
if x_ind == g_ind: g_row[g_indices[j_ptr]] = 0
gout_data[x_pos] = g_data[g_pos]
x_pos += 1
g_pos += 1
elif x_ind < g_ind:
x_pos += 1
else:
g_pos += 1
if self.kmap is None: if self.kmap is None:
g_out[0] = gout_data g_out[0] = gout_data
...@@ -789,19 +786,19 @@ class CSMGradC(gof.Op): ...@@ -789,19 +786,19 @@ class CSMGradC(gof.Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, a_val, a_ind, a_ptr, b_val, b_ind, b_ptr): def make_node(self, a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim):
return gof.Apply(self, [a_val, a_ind, a_ptr, b_val, b_ind, b_ptr], return gof.Apply(self, [a_val, a_ind, a_ptr, a_dim,
[b_val.type()]) b_val, b_ind, b_ptr, b_dim], [b_val.type()])
def c_code(self, node, name, (a_val, a_ind, a_ptr, def c_code(self, node, name, (a_val, a_ind, a_ptr, a_dim,
b_val, b_ind, b_ptr), (z,), sub): b_val, b_ind, b_ptr, b_dim), (z,), sub):
# retrieve dtype number # retrieve dtype number
typenum_z = node.outputs[0].type.dtype_specs()[-1] typenum_z = node.outputs[0].type.dtype_specs()[-1]
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a_val') raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b_val') raise NotImplementedError('Complex types are not supported for b_val')
return """ return """
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
...@@ -842,7 +839,11 @@ class CSMGradC(gof.Op): ...@@ -842,7 +839,11 @@ class CSMGradC(gof.Op):
{ {
// sparse array has size MxK, dense KxN, output MxN // sparse array has size MxK, dense KxN, output MxN
npy_intp M = %(a_ptr)s->dimensions[0] - 1; npy_intp M = %(a_ptr)s->dimensions[0] - 1;
npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0];
npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1];
npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0;
// strides tell you how many bytes to skip to go to next column/row entry // strides tell you how many bytes to skip to go to next column/row entry
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize; npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize;
npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize; npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize;
...@@ -862,33 +863,29 @@ class CSMGradC(gof.Op): ...@@ -862,33 +863,29 @@ class CSMGradC(gof.Op):
const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data; const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data;
npy_intp nnz = %(a_ind)s->dimensions[0]; npy_intp nnz = %(a_ind)s->dimensions[0];
dtype_%(b_val)s b_row[sp_dim];
//clear the output array //clear the output array
memset(Dz, 0, nnz*sizeof(dtype_%(z)s)); memset(Dz, 0, nnz*sizeof(dtype_%(z)s));
memset(b_row, 0, sp_dim*sizeof(dtype_%(b_val)s));
// loop over inner dimension // loop over inner dimension
for (npy_int64 m = 0; m < M; ++m) for (npy_int64 m = 0; m < M; ++m)
{ {
npy_int32 a_pos = Da_ptr[m * Sa_ptr]; for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr];
npy_int32 b_pos = Db_ptr[m * Sb_ptr]; j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
npy_int32 a_end = Da_ptr[(m + 1) * Sa_ptr]; b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val];
npy_int32 b_end = Db_ptr[(m + 1) * Sb_ptr]; }
while (a_pos < a_end && b_pos < b_end) { for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr];
npy_int32 a_ind = Da_ind[a_pos * Sa_ind]; j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) {
npy_int32 b_ind = Db_ind[b_pos * Sb_ind]; Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]];
}
if (a_ind == b_ind) {
Dz[a_pos*Sz] = Db_val[b_pos*Sb_val]; for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr];
a_pos++; j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_pos++; b_row[Db_ind[j_ptr * Sb_ind]] = 0;
}
else if (a_ind < b_ind) {
a_pos++;
}
else {
b_pos++;
}
} }
} }
} }
...@@ -896,12 +893,12 @@ class CSMGradC(gof.Op): ...@@ -896,12 +893,12 @@ class CSMGradC(gof.Op):
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
csm_grad_c = CSMGradC() csm_grad_c = CSMGradC()
@gof.local_optimizer([csm_grad(None)]) @gof.local_optimizer([csm_grad(None)])
def local_csm_grad_c(node): def local_csm_grad_c(node):
""" usmm -> usmm_csc_dense """ """ csm_grad(None) -> csm_grad_c """
if node.op == csm_grad(None): if node.op == csm_grad(None):
return [csm_grad_c(*node.inputs)] return [csm_grad_c(*node.inputs)]
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论