提交 8cfc382a authored 作者: James Bergstra's avatar James Bergstra

sparse - added name parameter to variable constructor

上级 9b304769
...@@ -75,7 +75,7 @@ def _kmap_hash(a): ...@@ -75,7 +75,7 @@ def _kmap_hash(a):
# Wrapper type # Wrapper type
def as_sparse_variable(x): def as_sparse_variable(x, name=None):
""" """
Wrapper around SparseVariable constructor. Wrapper around SparseVariable constructor.
@param x: A sparse matrix. as_sparse_variable reads dtype and format properties @param x: A sparse matrix. as_sparse_variable reads dtype and format properties
...@@ -94,18 +94,18 @@ def as_sparse_variable(x): ...@@ -94,18 +94,18 @@ def as_sparse_variable(x):
raise TypeError("Variable type field must be a SparseType.", x, x.type) raise TypeError("Variable type field must be a SparseType.", x, x.type)
return x return x
try: try:
return constant(x) return constant(x, name=name)
except TypeError: except TypeError:
raise TypeError("Cannot convert %s to SparseType" % x, type(x)) raise TypeError("Cannot convert %s to SparseType" % x, type(x))
as_sparse = as_sparse_variable as_sparse = as_sparse_variable
def constant(x): def constant(x, name=None):
if not isinstance(x, scipy.sparse.spmatrix): if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix") raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix")
try: try:
return SparseConstant(SparseType(format = x.format, return SparseConstant(SparseType(format = x.format,
dtype = x.dtype), x.copy()) dtype = x.dtype), x.copy(),name=name)
except TypeError: except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x)) raise TypeError("Could not convert %s to SparseType" % x, type(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论