提交 cc62c3a8 authored 作者: nouiz's avatar nouiz

Merge pull request #523 from ynd/sp_sandbox

Added MulSD, AddSS optimizations and new blas function declerations to sparse sandbox
...@@ -27,6 +27,266 @@ fcast = Cast('float32') ...@@ -27,6 +27,266 @@ fcast = Cast('float32')
dcast = Cast('float64') dcast = Cast('float64')
# register a specialization to replace AddSS -> AddSSData
@gof.local_optimizer([add_s_s])
def local_add_s_s(node):
"""
If two matrices are known to have the same sparsity pattern,
optimize the addition by only adding their data vector.
Very special case optimization. Activate when for add(x, y),
y is an expression like sp_ones_like(x) * another_matrix.
This is useful for sparse weight updates.
Work also for add(x, neg(y)) in the same case.
As of this writting sub is only implemented as x + neg(y) for sparse matrix.
"""
if node.op == add_s_s:
x, y = node.inputs
# In case addition was transformed to subtraction
if hasattr(y.owner, 'op') and y.owner.op == neg:
y_ = y.owner.inputs[0]
else:
y_ = y
if y_.owner is None:
return False
if hasattr(y_.owner, 'op') and y_.owner.op not in [mul_s_s, mul_s_d]:
return False
def same_pattern(node):
"""Check node has same sparsity as x."""
# In case the sparse matrix is multiplied by a scalar (ex: learning rate)
if hasattr(node.owner, 'op') and node.owner.op == mul_scalar:
node = node.owner.inputs[1]
# Check node creates a matrix
if not hasattr(node.owner, 'op') or not isinstance(node.owner.op, CSM):
return False
# Check matrix is creates from CSMProperties
if filter(lambda i: not hasattr(i.owner, 'op') or not isinstance(i.owner.op, CSMProperties), node.owner.inputs[1:]):
return False
# Verify indices, indptr and shape are the same as x
if filter(lambda i: i.owner.inputs[0] != x, node.owner.inputs[1:]):
return False
return True
if filter(same_pattern, y_.owner.inputs):
return [add_s_s_data(x, y)]
return False
register_specialize(local_add_s_s)
class AddSSData(gof.op.Op):
'''Add two sparse matrices assuming they have the same sparsity pattern. '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype = x.type.dtype,
format = x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
out[0] = x.copy()
out[0].data += y.data
add_s_s_data = AddSSData()
# register a specialization to replace MulSD -> MulSDCSX
@gof.local_optimizer([mul_s_d])
def local_mul_s_d(node):
if node.op == mul_s_d:
x, y = node.inputs
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
if x_is_sparse_variable:
svar = x
dvar = y
else:
svar = y
dvar = x
if dvar.type.ndim != 2:
return False
if svar.type.format == 'csc':
CSx = CSC
mul_s_d_csx = mul_s_d_csc
elif svar.type.format == 'csr':
CSx = CSR
mul_s_d_csx = mul_s_d_csr
else:
raise NotImplemented()
c_data = mul_s_d_csx(csm_data(svar), csm_indices(svar), csm_indptr(svar), dvar)
return [CSx(c_data, csm_indices(svar), csm_indptr(svar), csm_shape(svar))]
return False
register_specialize(local_mul_s_d)
class MulSDCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplementedError()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a')
if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b')
return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;}
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if (!%(_zout)s)
{
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num);
}
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s;
}
{ //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data;
dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data;
const npy_intp Sb = %(_b)s->strides[0];
// loop over columns
for (npy_int32 j = 0; j < N; ++j)
{
// for each non-null value in the sparse column
for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
{
// extract row index of non-null value
npy_int32 i = indices[i_idx];
// extract i-th row of dense matrix
const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * i);
// write resulting gradient to sparse output
zout[i_idx] = data[i_idx] * b_row[j];
}
}
}
"""% dict(locals(), **sub)
mul_s_d_csc = MulSDCSC()
class MulSDCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplemented()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a')
if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b')
return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;}
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if (!%(_zout)s)
{
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num);
}
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s;
}
{ //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data;
dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data;
const npy_intp Sb = %(_b)s->strides[0];
// loop over columns
for (npy_int32 j = 0; j < N; ++j)
{
// extract i-th row of dense matrix
const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * j);
// for each non-null value in the sparse column
for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
{
// extract row index of non-null value
npy_int32 i = indices[i_idx];
// write resulting gradient to sparse output
zout[i_idx] = data[i_idx] * b_row[i];
}
}
}
"""% dict(locals(), **sub)
mul_s_d_csr = MulSDCSR()
class Poisson(gof.op.Op): class Poisson(gof.op.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -46,7 +306,6 @@ class Poisson(gof.op.Op): ...@@ -46,7 +306,6 @@ class Poisson(gof.op.Op):
out[0].eliminate_zeros() out[0].eliminate_zeros()
poisson = Poisson() poisson = Poisson()
class Multinomial(gof.op.Op): class Multinomial(gof.op.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -483,10 +742,10 @@ class SamplingDotCsr(gof.Op): ...@@ -483,10 +742,10 @@ class SamplingDotCsr(gof.Op):
if dot_out == "float32": if dot_out == "float32":
conv_type = "float" conv_type = "float"
cdot = "sdot_sub_" cdot = "sdot_"
else: else:
conv_type = "double" conv_type = "double"
cdot = "ddot_sub_" cdot = "ddot_"
# retrieve dtype number # retrieve dtype number
typenum_x = node.inputs[0].type.dtype_specs()[-1] typenum_x = node.inputs[0].type.dtype_specs()[-1]
...@@ -580,9 +839,7 @@ class SamplingDotCsr(gof.Op): ...@@ -580,9 +839,7 @@ class SamplingDotCsr(gof.Op):
const dtype_%(y)s* y_col = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * n); const dtype_%(y)s* y_col = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * n);
%(cdot)s((int*)&K, (const %(conv_type)s*)x_row, (int*)&Sdx, (const %(conv_type)s*)y_col, (int*)&Sdy, &Dzd[n_idx * Sdzd]); Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s((int*)&K, (const %(conv_type)s*)x_row, (int*)&Sdx, (const %(conv_type)s*)y_col, (int*)&Sdy);
Dzd[n_idx * Sdzd] *= Dpd[n_idx * Sdpd];
} }
} }
} }
......
...@@ -604,6 +604,7 @@ def blas_header_text(): ...@@ -604,6 +604,7 @@ def blas_header_text():
void sswap_( const int*, float *, const int*, float *, const int*); void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*); void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*); void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
float sdot_(const int*, const float *, const int*, const float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *); void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *); void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*); void sscal_( const int*, const float *, float *, const int*);
...@@ -621,6 +622,7 @@ def blas_header_text(): ...@@ -621,6 +622,7 @@ def blas_header_text():
void dcopy_( const int*, const double *, const int*, double *, const int*); void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*); void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*); void dswap_( const int*, double *, const int*, double *, const int*);
double ddot_(const int*, const double *, const int*, const double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *); void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *); void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*); void dscal_( const int*, const double *, double *, const int*);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论