提交 25682947 authored 作者: James Bergstra's avatar James Bergstra

Added type constructors to sparse module: csc_dmatrix, csc_matrix, etc.

上级 f7550e7f
...@@ -15,6 +15,7 @@ from theano import gof ...@@ -15,6 +15,7 @@ from theano import gof
from theano import tensor from theano import tensor
from theano import compile from theano import compile
from theano import scalar from theano import scalar
from theano import config
#TODO: move this decorator to the compile submodule #TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
...@@ -138,7 +139,7 @@ class SparseType(gof.Type): ...@@ -138,7 +139,7 @@ class SparseType(gof.Type):
dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128']) dtype_set = set(['int', 'int8', 'int16','int32', 'int64', 'float32', 'float64', 'complex64','complex128'])
ndim = 2 ndim = 2
def __init__(self, format, dtype = 'float64'): def __init__(self, format, dtype):
""" """
Fundamental way to create a sparse node. Fundamental way to create a sparse node.
@param dtype: Type of numbers in the matrix. @param dtype: Type of numbers in the matrix.
...@@ -205,8 +206,13 @@ class SparseType(gof.Type): ...@@ -205,8 +206,13 @@ class SparseType(gof.Type):
def is_valid_value(self, a): def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format) return scipy.sparse.issparse(a) and (a.format == self.format)
csc_matrix = SparseType(format='csc') # for more dtypes, call SparseType(format, dtype)
csr_matrix = SparseType(format='csr') csc_matrix = SparseType(format='csc', dtype=config.floatX)
csr_matrix = SparseType(format='csr', dtype=config.floatX)
csc_dmatrix = SparseType(format='csc', dtype='float64')
csr_dmatrix = SparseType(format='csr', dtype='float64')
csc_fmatrix = SparseType(format='csc', dtype='float32')
csr_fmatrix = SparseType(format='csr', dtype='float32')
class _sparse_py_operators: class _sparse_py_operators:
T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)") T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论