提交 a5b7de8b authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

base fix for bug in CSM and CSMGrad

上级 cbe15896
...@@ -698,10 +698,12 @@ class CSM(gof.Op): ...@@ -698,10 +698,12 @@ class CSM(gof.Op):
indptr.copy()), shape.copy(), indptr.copy()), shape.copy(),
copy=False) copy=False)
def grad(self, (data, indices, indptr, shape), (g_out,)): def grad(self, (x_data, x_indices, x_indptr, _), (g_out,)):
"""Return a gradient on the data vector""" """Return a gradient on the data vector"""
g_data, g_indices, g_indptr, _ = csm_properties(g_out)
#unpack the data vector and wrap it as a 1d TensorType #unpack the data vector and wrap it as a 1d TensorType
g_data = csm_grad(self.kmap)(data, csm_data(g_out), csm_indices(g_out)) g_data = csm_grad(self.kmap)(x_data, x_indices, x_indptr,
g_data, g_indices, g_indptr)
return [g_data, None, None, None] return [g_data, None, None, None]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
...@@ -733,17 +735,40 @@ class CSMGrad(gof.op.Op): ...@@ -733,17 +735,40 @@ class CSMGrad(gof.op.Op):
self.__class__.__name__, self.__class__.__name__,
self.kmap) self.kmap)
def make_node(self, data, gout_data, gout_indices): def make_node(self, x_data, x_indices, x_indptr,
g_data = gout_data.type() g_data, g_indices, g_indptr):
return gof.Apply(self, [data, gout_data, gout_indices], [g_data]) gout_data = g_data.type()
return gof.Apply(self, [x_data, x_indices, x_indptr,
def perform(self, node, (data, gout_data, gout_indices), (g_data,)): g_data, g_indices, g_indptr], [gout_data])
def perform(self, node, (x_data, x_indices, x_indptr,
g_data, g_indices, g_indptr), (g_out,)):
gout_data = numpy.zeros_like(x_data)
for i in range(len(x_indptr) - 1):
x_pos = x_indptr[i]
g_pos = g_indptr[i]
x_end = x_indptr[i + 1]
g_end = g_indptr[i + 1]
while x_pos < x_end and g_pos < g_end:
x_ind = x_indices[x_pos]
g_ind = g_indices[g_pos]
if x_ind == g_ind:
gout_data[x_pos] = g_data[g_pos]
x_pos += 1
g_pos += 1
elif x_ind < g_ind:
x_pos += 1
else:
g_pos += 1
if self.kmap is None: if self.kmap is None:
g_data[0] = gout_data g_out[0] = gout_data
else: else:
grad = numpy.zeros_like(data) grad = numpy.zeros_like(x_data)
grad[self.kmap] = gout_data grad[self.kmap] = gout_data
g_data[0] = grad g_out[0] = grad
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
if self.kmap is None: if self.kmap is None:
...@@ -754,6 +779,134 @@ class CSMGrad(gof.op.Op): ...@@ -754,6 +779,134 @@ class CSMGrad(gof.op.Op):
csm_grad = CSMGrad csm_grad = CSMGrad
class CSMGradC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def make_node(self, a_val, a_ind, a_ptr, b_val, b_ind, b_ptr):
return gof.Apply(self, [a_val, a_ind, a_ptr, b_val, b_ind, b_ptr],
[b_val.type()])
def c_code(self, node, name, (a_val, a_ind, a_ptr,
b_val, b_ind, b_ptr), (z,), sub):
# retrieve dtype number
typenum_z = node.outputs[0].type.dtype_specs()[-1]
if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b_val')
return """
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
if (%(a_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;}
if (%(b_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_val) != 1"); %(fail)s;}
if (%(b_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ind) != 1"); %(fail)s;}
if (%(b_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ptr) != 1"); %(fail)s;}
if (%(a_ind)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;}
if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(b_ind)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;}
if (%(b_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;}
if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0]))
{
{Py_XDECREF(%(z)s);}
npy_intp dims[] = {0};
dims[0] = %(a_val)s->dimensions[0];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_z)s);
}
{
// sparse array has size MxK, dense KxN, output MxN
npy_intp M = %(a_ptr)s->dimensions[0] - 1;
// strides tell you how many bytes to skip to go to next column/row entry
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize;
npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize;
npy_intp Sa_ind = %(a_ind)s->strides[0] / %(a_ind)s->descr->elsize;
npy_intp Sa_ptr = %(a_ptr)s->strides[0] / %(a_ptr)s->descr->elsize;
npy_intp Sb_val = %(b_val)s->strides[0] / %(b_val)s->descr->elsize;
npy_intp Sb_ind = %(b_ind)s->strides[0] / %(b_ind)s->descr->elsize;
npy_intp Sb_ptr = %(b_ptr)s->strides[0] / %(b_ptr)s->descr->elsize;
// pointers to access actual data in the arrays passed as params.
dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)%(z)s->data;
const dtype_%(a_val)s* __restrict__ Da_val = (dtype_%(a_val)s*)%(a_val)s->data;
const npy_int32 * __restrict__ Da_ind = (npy_int32*)%(a_ind)s->data;
const npy_int32 * __restrict__ Da_ptr = (npy_int32*)%(a_ptr)s->data;
const dtype_%(b_val)s* __restrict__ Db_val = (dtype_%(b_val)s*)%(b_val)s->data;
const npy_int32 * __restrict__ Db_ind = (npy_int32*)%(b_ind)s->data;
const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data;
npy_intp nnz = %(a_ind)s->dimensions[0];
//clear the output array
memset(Dz, 0, nnz*sizeof(dtype_%(z)s));
// loop over inner dimension
for (npy_int64 m = 0; m < M; ++m)
{
npy_int32 a_pos = Da_ptr[m * Sa_ptr];
npy_int32 b_pos = Db_ptr[m * Sb_ptr];
npy_int32 a_end = Da_ptr[(m + 1) * Sa_ptr];
npy_int32 b_end = Db_ptr[(m + 1) * Sb_ptr];
while (a_pos < a_end && b_pos < b_end) {
npy_int32 a_ind = Da_ind[a_pos * Sa_ind];
npy_int32 b_ind = Db_ind[b_pos * Sb_ind];
if (a_ind == b_ind) {
Dz[a_pos*Sz] = Db_val[b_pos*Sb_val];
a_pos++;
b_pos++;
}
else if (a_ind < b_ind) {
a_pos++;
}
else {
b_pos++;
}
}
}
}
""" % dict(locals(), **sub)
def c_code_cache_version(self):
return (1,)
csm_grad_c = CSMGradC()
@gof.local_optimizer([csm_grad(None)])
def local_csm_grad_c(node):
""" usmm -> usmm_csc_dense """
if node.op == csm_grad(None):
return [csm_grad_c(*node.inputs)]
return False
register_specialize(local_csm_grad_c)
# #
# Conversion # Conversion
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论