提交 25682947 authored 作者: James Bergstra's avatar James Bergstra

Added type constructors to sparse module: csc_dmatrix, csc_matrix, etc.

上级 f7550e7f
......@@ -15,6 +15,7 @@ from theano import gof
from theano import tensor
from theano import compile
from theano import scalar
from theano import config
#TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs):
......@@ -138,7 +139,7 @@ class SparseType(gof.Type):
dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128'])
ndim = 2
def __init__(self, format, dtype = 'float64'):
def __init__(self, format, dtype):
"""
Fundamental way to create a sparse node.
@param dtype: Type of numbers in the matrix.
......@@ -205,8 +206,13 @@ class SparseType(gof.Type):
def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format)
csc_matrix = SparseType(format='csc')
csr_matrix = SparseType(format='csr')
# for more dtypes, call SparseType(format, dtype)
csc_matrix = SparseType(format='csc', dtype=config.floatX)
csr_matrix = SparseType(format='csr', dtype=config.floatX)
csc_dmatrix = SparseType(format='csc', dtype='float64')
csr_dmatrix = SparseType(format='csr', dtype='float64')
csc_fmatrix = SparseType(format='csc', dtype='float32')
csr_fmatrix = SparseType(format='csr', dtype='float32')
class _sparse_py_operators:
T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论