提交 b736b12e authored 作者: Frederic's avatar Frederic

Add function that make code easier in sparse.

上级 d47e2c12
...@@ -327,6 +327,16 @@ class SparseType(gof.Type): ...@@ -327,6 +327,16 @@ class SparseType(gof.Type):
return scipy.sparse.issparse(a) and (a.format == self.format) return scipy.sparse.issparse(a) and (a.format == self.format)
# for more dtypes, call SparseType(format, dtype) # for more dtypes, call SparseType(format, dtype)
def matrix(format, name=None, dtype=None):
if dtype is None:
dtype = config.floatX
type = SparseType(format=format, dtype=dtype)
return type(name)
def csc_matrix(name=None, dtype=None):
return matrix('csc', name, dtype)
def csr_matrix(name=None, dtype=None):
return matrix('csr', name, dtype)
# for more dtypes, call SparseType(format, dtype)
csc_matrix = SparseType(format='csc', dtype=config.floatX) csc_matrix = SparseType(format='csc', dtype=config.floatX)
csr_matrix = SparseType(format='csr', dtype=config.floatX) csr_matrix = SparseType(format='csr', dtype=config.floatX)
csc_dmatrix = SparseType(format='csc', dtype='float64') csc_dmatrix = SparseType(format='csc', dtype='float64')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论