Added some documentation.

上级 c8c2c4f7
...@@ -9,7 +9,12 @@ import tensor ...@@ -9,7 +9,12 @@ import tensor
# Wrapper type # Wrapper type
def assparse(sp, **kwargs): def assparse(sp, **kwargs):
"""Return SparseR version of sp""" """
Wrapper around SparseR constructor.
@param sp: A sparse matrix. assparse reads dtype and format properties
out of this sparse matrix.
@return: SparseR version of sp.
"""
if isinstance(sp, SparseR): if isinstance(sp, SparseR):
return sp return sp
else: else:
...@@ -37,6 +42,13 @@ class SparseR(gof.result.ResultBase): ...@@ -37,6 +42,13 @@ class SparseR(gof.result.ResultBase):
dtype_set = set(['int', 'int32', 'int64', 'float32', 'float64']) dtype_set = set(['int', 'int32', 'int64', 'float32', 'float64'])
def __init__(self, dtype, format, **kwargs): def __init__(self, dtype, format, **kwargs):
"""
Fundamental way to do create a sparse node.
@param dtype: Type of numbers in the matrix.
@param format: The sparse storage strategy.
@return An empty SparseR instance.
"""
gof.ResultBase.__init__(self, **kwargs) gof.ResultBase.__init__(self, **kwargs)
if dtype in SparseR.dtype_set: if dtype in SparseR.dtype_set:
self._dtype = dtype self._dtype = dtype
...@@ -165,8 +177,10 @@ if 0: ...@@ -165,8 +177,10 @@ if 0:
def gen_outputs(self): return [SparseR()] def gen_outputs(self): return [SparseR()]
def impl(x,y): def impl(x,y):
if hasattr(x, 'getnnz'): if hasattr(x, 'getnnz'):
# if x is sparse, then do this.
return x.dot(y) return x.dot(y)
else: else:
# if x is dense (and y is sparse), we do this
return y.transpose().dot(x.transpose()).transpose() return y.transpose().dot(x.transpose()).transpose()
def grad(self, x, y, gz): def grad(self, x, y, gz):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论