提交 f7550e7f authored 作者: James Bergstra's avatar James Bergstra

sparse/basic: s/sparse/scipy.sparse

上级 eb535917
......@@ -8,7 +8,6 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
import sys, operator
import numpy, theano
from scipy import sparse
import scipy.sparse
from theano.printing import Print
......@@ -23,11 +22,11 @@ def register_specialize(lopt, *tags, **kwargs):
""" Types of sparse matrices to use for testing """
_mtypes = [sparse.csc_matrix, sparse.csr_matrix]
_mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, sparse.lil_matrix, sparse.coo_matrix]
#* new class ``dia_matrix`` : the sparse DIAgonal format
#* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str = {sparse.csc_matrix: "csc", sparse.csr_matrix: "csr"}
_mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"}
def _is_sparse_variable(x):
"""
......@@ -51,15 +50,15 @@ def _is_sparse(x):
@rtype: boolean
@return: True iff x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray})
"""
if not isinstance(x, sparse.spmatrix) and not isinstance(x, numpy.ndarray):
if not isinstance(x, scipy.sparse.spmatrix) and not isinstance(x, numpy.ndarray):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, sparse.spmatrix)
return isinstance(x, scipy.sparse.spmatrix)
def _is_dense(x):
"""
@rtype: boolean
@return: True unless x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray})
"""
if not isinstance(x, sparse.spmatrix) and not isinstance(x, numpy.ndarray):
if not isinstance(x, scipy.sparse.spmatrix) and not isinstance(x, numpy.ndarray):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray)
......@@ -101,7 +100,7 @@ def as_sparse_variable(x):
as_sparse = as_sparse_variable
def constant(x):
if not isinstance(x, sparse.spmatrix):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix")
try:
return SparseConstant(SparseType(format = x.format,
......@@ -109,14 +108,15 @@ def constant(x):
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def value(x):
if not isinstance(x, sparse.spmatrix):
raise TypeError("sparse.value must be called on a scipy.sparse.spmatrix")
try:
return SparseValue(SparseType(format = x.format,
dtype = x.dtype), x)
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
if 0:
def value(x):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.value must be called on a scipy.sparse.spmatrix")
try:
return SparseValue(SparseType(format = x.format,
dtype = x.dtype), x)
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x):
data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
......@@ -132,8 +132,8 @@ class SparseType(gof.Type):
@note As far as I can tell, L{scipy.sparse} objects must be matrices, i.e. have dimension 2.
"""
format_cls = {
'csr' : sparse.csr_matrix,
'csc' : sparse.csc_matrix
'csr' : scipy.sparse.csr_matrix,
'csc' : scipy.sparse.csc_matrix
}
dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128'])
ndim = 2
......@@ -187,11 +187,21 @@ class SparseType(gof.Type):
return "Sparse[%s, %s]" % (str(self.dtype), str(self.format))
def values_eq_approx(self, a, b, eps=1e-6):
# print "VEA", a, b, scipy.sparse.issparse(a), scipy.sparse.issparse(b), abs(a-b).sum(), abs(a-b).sum() < (1e-6 * a.nnz)
#WARNING: equality comparison of sparse matrices is not fast or easy
# we definitely do not want to be doing this un-necessarily during
# a FAST_RUN computation..
return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \
and abs(a-b).sum() < (1e-6 * a.nnz)
def values_eq(self, a, b):
#WARNING: equality comparison of sparse matrices is not fast or easy
# we definitely do not want to be doing this un-necessarily during
# a FAST_RUN computation..
return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \
and abs(a-b).sum() == 0.0
def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format)
......@@ -377,13 +387,13 @@ class CSM(gof.Op):
'as indices (shape'+`indices.shape`+') or elements as kmap ('+`numpy.size(self.kmap)`+')'
raise ValueError(errmsg)
if self.format == 'csc':
out[0] = sparse.csc_matrix((data, indices.copy(), indptr.copy()),
out[0] = scipy.sparse.csc_matrix((data, indices.copy(), indptr.copy()),
numpy.asarray(shape),
copy = False #1000*len(data.flatten())
)
else:
assert self.format == 'csr'
out[0] = sparse.csr_matrix((data, indices.copy(), indptr.copy()),
out[0] = scipy.sparse.csr_matrix((data, indices.copy(), indptr.copy()),
shape.copy(),
copy = False #1000*len(data.flatten())
)
......@@ -795,7 +805,7 @@ class StructuredDotCSC(gof.Op):
return r
def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)):
a = sparse.csc_matrix((a_val, a_ind, a_ptr),
a = scipy.sparse.csc_matrix((a_val, a_ind, a_ptr),
(a_nrows, b.shape[0]),
copy = False)
#out[0] = a.dot(b)
......@@ -952,7 +962,7 @@ class StructuredDotCSR(gof.Op):
return r
def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
a = sparse.csr_matrix((a_val, a_ind, a_ptr),
a = scipy.sparse.csr_matrix((a_val, a_ind, a_ptr),
(len(a_ptr)-1, b.shape[0]),
copy = True) #use view_map before setting this to False
#out[0] = a.dot(b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论