提交 30921ef6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle static shape in core sparse methods

上级 ec2351bf
...@@ -207,19 +207,19 @@ def sp_zeros_like(x): ...@@ -207,19 +207,19 @@ def sp_zeros_like(x):
# for more dtypes, call SparseTensorType(format, dtype) # for more dtypes, call SparseTensorType(format, dtype)
def matrix(format, name=None, dtype=None): def matrix(format, name=None, dtype=None, shape=None):
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
type = SparseTensorType(format=format, dtype=dtype) type = SparseTensorType(format=format, dtype=dtype, shape=shape)
return type(name) return type(name)
def csc_matrix(name=None, dtype=None): def csc_matrix(name=None, dtype=None, shape=None):
return matrix("csc", name, dtype) return matrix("csc", name=name, dtype=dtype, shape=shape)
def csr_matrix(name=None, dtype=None): def csr_matrix(name=None, dtype=None, shape=None):
return matrix("csr", name, dtype) return matrix("csr", name=name, dtype=dtype, shape=shape)
def bsr_matrix(name=None, dtype=None): def bsr_matrix(name=None, dtype=None):
...@@ -434,10 +434,22 @@ class CSM(Op): ...@@ -434,10 +434,22 @@ class CSM(Op):
if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes: if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes:
raise TypeError("n_rows must be integer type", shape, shape.type) raise TypeError("n_rows must be integer type", shape, shape.type)
static_shape = (None, None)
if (
shape.owner is not None
and isinstance(shape.owner.op, CSMProperties)
and shape.owner.outputs[3] is shape
):
static_shape = shape.owner.inputs[0].type.shape
return Apply( return Apply(
self, self,
[data, indices, indptr, shape], [data, indices, indptr, shape],
[SparseTensorType(dtype=data.type.dtype, format=self.format)()], [
SparseTensorType(
dtype=data.type.dtype, format=self.format, shape=static_shape
)()
],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -698,7 +710,7 @@ class DenseFromSparse(Op): ...@@ -698,7 +710,7 @@ class DenseFromSparse(Op):
return Apply( return Apply(
self, self,
[x], [x],
[TensorType(dtype=x.type.dtype, shape=(None, None))()], [TensorType(dtype=x.type.dtype, shape=x.type.shape)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
......
...@@ -127,6 +127,8 @@ class _sparse_py_operators: ...@@ -127,6 +127,8 @@ class _sparse_py_operators:
def toarray(self): def toarray(self):
return dense_from_sparse(self) return dense_from_sparse(self)
todense = toarray
@property @property
def shape(self): def shape(self):
# TODO: The plan is that the ShapeFeature in ptb.opt will do shape # TODO: The plan is that the ShapeFeature in ptb.opt will do shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论