提交 f7550e7f authored 作者: James Bergstra's avatar James Bergstra

sparse/basic: s/sparse/scipy.sparse

上级 eb535917
...@@ -8,7 +8,6 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/ ...@@ -8,7 +8,6 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
import sys, operator import sys, operator
import numpy, theano import numpy, theano
from scipy import sparse
import scipy.sparse import scipy.sparse
from theano.printing import Print from theano.printing import Print
...@@ -23,11 +22,11 @@ def register_specialize(lopt, *tags, **kwargs): ...@@ -23,11 +22,11 @@ def register_specialize(lopt, *tags, **kwargs):
""" Types of sparse matrices to use for testing """ """ Types of sparse matrices to use for testing """
_mtypes = [sparse.csc_matrix, sparse.csr_matrix] _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, sparse.lil_matrix, sparse.coo_matrix] #_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, sparse.lil_matrix, sparse.coo_matrix]
#* new class ``dia_matrix`` : the sparse DIAgonal format #* new class ``dia_matrix`` : the sparse DIAgonal format
#* new class ``bsr_matrix`` : the Block CSR format #* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str = {sparse.csc_matrix: "csc", sparse.csr_matrix: "csr"} _mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"}
def _is_sparse_variable(x): def _is_sparse_variable(x):
""" """
...@@ -51,15 +50,15 @@ def _is_sparse(x): ...@@ -51,15 +50,15 @@ def _is_sparse(x):
@rtype: boolean @rtype: boolean
@return: True iff x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray}) @return: True iff x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray})
""" """
if not isinstance(x, sparse.spmatrix) and not isinstance(x, numpy.ndarray): if not isinstance(x, scipy.sparse.spmatrix) and not isinstance(x, numpy.ndarray):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x) raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, sparse.spmatrix) return isinstance(x, scipy.sparse.spmatrix)
def _is_dense(x): def _is_dense(x):
""" """
@rtype: boolean @rtype: boolean
@return: True unless x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray}) @return: True unless x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray})
""" """
if not isinstance(x, sparse.spmatrix) and not isinstance(x, numpy.ndarray): if not isinstance(x, scipy.sparse.spmatrix) and not isinstance(x, numpy.ndarray):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x) raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray) return isinstance(x, numpy.ndarray)
...@@ -101,7 +100,7 @@ def as_sparse_variable(x): ...@@ -101,7 +100,7 @@ def as_sparse_variable(x):
as_sparse = as_sparse_variable as_sparse = as_sparse_variable
def constant(x): def constant(x):
if not isinstance(x, sparse.spmatrix): if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix") raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix")
try: try:
return SparseConstant(SparseType(format = x.format, return SparseConstant(SparseType(format = x.format,
...@@ -109,14 +108,15 @@ def constant(x): ...@@ -109,14 +108,15 @@ def constant(x):
except TypeError: except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x)) raise TypeError("Could not convert %s to SparseType" % x, type(x))
def value(x): if 0:
if not isinstance(x, sparse.spmatrix): def value(x):
raise TypeError("sparse.value must be called on a scipy.sparse.spmatrix") if not isinstance(x, scipy.sparse.spmatrix):
try: raise TypeError("sparse.value must be called on a scipy.sparse.spmatrix")
return SparseValue(SparseType(format = x.format, try:
dtype = x.dtype), x) return SparseValue(SparseType(format = x.format,
except TypeError: dtype = x.dtype), x)
raise TypeError("Could not convert %s to SparseType" % x, type(x)) except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x): def sp_ones_like(x):
data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
...@@ -132,8 +132,8 @@ class SparseType(gof.Type): ...@@ -132,8 +132,8 @@ class SparseType(gof.Type):
@note As far as I can tell, L{scipy.sparse} objects must be matrices, i.e. have dimension 2. @note As far as I can tell, L{scipy.sparse} objects must be matrices, i.e. have dimension 2.
""" """
format_cls = { format_cls = {
'csr' : sparse.csr_matrix, 'csr' : scipy.sparse.csr_matrix,
'csc' : sparse.csc_matrix 'csc' : scipy.sparse.csc_matrix
} }
dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128']) dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128'])
ndim = 2 ndim = 2
...@@ -187,11 +187,21 @@ class SparseType(gof.Type): ...@@ -187,11 +187,21 @@ class SparseType(gof.Type):
return "Sparse[%s, %s]" % (str(self.dtype), str(self.format)) return "Sparse[%s, %s]" % (str(self.dtype), str(self.format))
def values_eq_approx(self, a, b, eps=1e-6): def values_eq_approx(self, a, b, eps=1e-6):
# print "VEA", a, b, scipy.sparse.issparse(a), scipy.sparse.issparse(b), abs(a-b).sum(), abs(a-b).sum() < (1e-6 * a.nnz) #WARNING: equality comparison of sparse matrices is not fast or easy
# we definitely do not want to be doing this un-necessarily during
# a FAST_RUN computation..
return scipy.sparse.issparse(a) \ return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \ and scipy.sparse.issparse(b) \
and abs(a-b).sum() < (1e-6 * a.nnz) and abs(a-b).sum() < (1e-6 * a.nnz)
def values_eq(self, a, b):
#WARNING: equality comparison of sparse matrices is not fast or easy
# we definitely do not want to be doing this un-necessarily during
# a FAST_RUN computation..
return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \
and abs(a-b).sum() == 0.0
def is_valid_value(self, a): def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format) return scipy.sparse.issparse(a) and (a.format == self.format)
...@@ -377,13 +387,13 @@ class CSM(gof.Op): ...@@ -377,13 +387,13 @@ class CSM(gof.Op):
'as indices (shape'+`indices.shape`+') or elements as kmap ('+`numpy.size(self.kmap)`+')' 'as indices (shape'+`indices.shape`+') or elements as kmap ('+`numpy.size(self.kmap)`+')'
raise ValueError(errmsg) raise ValueError(errmsg)
if self.format == 'csc': if self.format == 'csc':
out[0] = sparse.csc_matrix((data, indices.copy(), indptr.copy()), out[0] = scipy.sparse.csc_matrix((data, indices.copy(), indptr.copy()),
numpy.asarray(shape), numpy.asarray(shape),
copy = False #1000*len(data.flatten()) copy = False #1000*len(data.flatten())
) )
else: else:
assert self.format == 'csr' assert self.format == 'csr'
out[0] = sparse.csr_matrix((data, indices.copy(), indptr.copy()), out[0] = scipy.sparse.csr_matrix((data, indices.copy(), indptr.copy()),
shape.copy(), shape.copy(),
copy = False #1000*len(data.flatten()) copy = False #1000*len(data.flatten())
) )
...@@ -795,7 +805,7 @@ class StructuredDotCSC(gof.Op): ...@@ -795,7 +805,7 @@ class StructuredDotCSC(gof.Op):
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)):
a = sparse.csc_matrix((a_val, a_ind, a_ptr), a = scipy.sparse.csc_matrix((a_val, a_ind, a_ptr),
(a_nrows, b.shape[0]), (a_nrows, b.shape[0]),
copy = False) copy = False)
#out[0] = a.dot(b) #out[0] = a.dot(b)
...@@ -952,7 +962,7 @@ class StructuredDotCSR(gof.Op): ...@@ -952,7 +962,7 @@ class StructuredDotCSR(gof.Op):
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
a = sparse.csr_matrix((a_val, a_ind, a_ptr), a = scipy.sparse.csr_matrix((a_val, a_ind, a_ptr),
(len(a_ptr)-1, b.shape[0]), (len(a_ptr)-1, b.shape[0]),
copy = True) #use view_map before setting this to False copy = True) #use view_map before setting this to False
#out[0] = a.dot(b) #out[0] = a.dot(b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论