提交 61344155 authored 作者: James Bergstra's avatar James Bergstra

fixes to StructuredDotCSR

上级 b914ef2b
...@@ -869,20 +869,20 @@ sd_csc = StructuredDotCSC() ...@@ -869,20 +869,20 @@ sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op): class StructuredDotCSR(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_ncols, b): def make_node(self, a_val, a_ind, a_ptr, b):
assert a_val.type.dtype == b.type.dtype assert a_val.type.dtype == b.type.dtype
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_ncols, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
[tensor.tensor(a_val.type.dtype, (False, False))]) [tensor.tensor(a_val.type.dtype, (False, False))])
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, a_ncols, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
a = sparse.csr_matrix((a_val, a_ind, a_ptr), a = sparse.csr_matrix((a_val, a_ind, a_ptr),
(a_ncols, b.shape[0]), (len(a_ptr)-1, b.shape[0]),
copy = False) copy = True) #use view_map before setting this to False
out[0] = a.dot(b) out[0] = a.dot(b)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense, but not .6 sometimes
def c_code(self, node, name, (a_val, a_ind, a_ptr, a_ncols, b), (z,), sub): def c_code(self, node, name, (a_val, a_ind, a_ptr, b), (z,), sub):
""" """
C-implementation of the dot product of the sparse matrix A and matrix B. C-implementation of the dot product of the sparse matrix A and matrix B.
@param a_val: non-zero values of the sparse matrix @param a_val: non-zero values of the sparse matrix
...@@ -897,7 +897,6 @@ class StructuredDotCSR(gof.Op): ...@@ -897,7 +897,6 @@ class StructuredDotCSR(gof.Op):
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
if (%(a_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} if (%(a_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;}
if (%(a_ncols)s->nd != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(ncols) != 0"); %(fail)s;}
if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(a_val)s->descr->type_num != PyArray_DOUBLE) if (%(a_val)s->descr->type_num != PyArray_DOUBLE)
...@@ -909,26 +908,20 @@ class StructuredDotCSR(gof.Op): ...@@ -909,26 +908,20 @@ class StructuredDotCSR(gof.Op):
if (%(a_ptr)s->descr->type_num != PyArray_INT32) if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(a_ncols)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ncols dtype not INT32"); %(fail)s;}
if (%(b)s->descr->type_num != PyArray_DOUBLE) if (%(b)s->descr->type_num != PyArray_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0]) if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b)s->dimensions[0]+1)
{PyErr_SetString(PyExc_NotImplementedError, "a's number of columns doesn't match b's rows"); %(fail)s;}
if ((!%(z)s) if ((!%(z)s)
|| (%(z)s->dimensions[0] != ((npy_int32 *)%(a_ncols)s->data)[0]) || (%(z)s->dimensions[0] != %(a_ptr)s->dimensions[0]-1) //a's rows
|| (%(z)s->dimensions[1] != %(b)s->dimensions[1]) || (%(z)s->dimensions[1] != %(b)s->dimensions[1]) //b's columns
) )
{ {
if (%(z)s) Py_DECREF(%(z)s); if (%(z)s) Py_DECREF(%(z)s);
npy_intp dims[] = {0,0}; npy_intp dims[] = {0,0};
dims[0] = ((npy_int32 *)%(a_ncols)s->data)[0]; dims[0] = %(a_ptr)s->dimensions[0]-1;
dims[1] = %(b)s->dimensions[1]; dims[1] = %(b)s->dimensions[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num);
} }
...@@ -1011,11 +1004,13 @@ sd_csr = StructuredDotCSR() ...@@ -1011,11 +1004,13 @@ sd_csr = StructuredDotCSR()
def local_structured_dot(node): def local_structured_dot(node):
if node.op == _structured_dot: if node.op == _structured_dot:
a, b = node.inputs a, b = node.inputs
if a.type.format in ('csc','csr'): if a.type.format == 'csc':
a_val, a_ind, a_ptr, a_shape = csm_properties(a) a_val, a_ind, a_ptr, a_shape = csm_properties(a)
a_nsparse = a_shape[0] a_nsparse = a_shape[0]
sd_csx = sd_csc if a.type.format == 'csc' else sd_csr return [sd_csc(a_val, a_ind, a_ptr, a_nsparse, b)]
return [sd_csx(a_val,a_ind, a_ptr, a_nsparse, b)] if a.type.format == 'csr':
a_val, a_ind, a_ptr, a_shape = csm_properties(a)
return [sd_csr(a_val, a_ind, a_ptr, b)]
return False return False
register_specialize(local_structured_dot) register_specialize(local_structured_dot)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论