提交 30921ef6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle static shape in core sparse methods

上级 ec2351bf
......@@ -207,19 +207,19 @@ def sp_zeros_like(x):
# for more dtypes, call SparseTensorType(format, dtype)
def matrix(format, name=None, dtype=None):
def matrix(format, name=None, dtype=None, shape=None):
if dtype is None:
dtype = config.floatX
type = SparseTensorType(format=format, dtype=dtype)
type = SparseTensorType(format=format, dtype=dtype, shape=shape)
return type(name)
def csc_matrix(name=None, dtype=None):
return matrix("csc", name, dtype)
def csc_matrix(name=None, dtype=None, shape=None):
return matrix("csc", name=name, dtype=dtype, shape=shape)
def csr_matrix(name=None, dtype=None):
return matrix("csr", name, dtype)
def csr_matrix(name=None, dtype=None, shape=None):
return matrix("csr", name=name, dtype=dtype, shape=shape)
def bsr_matrix(name=None, dtype=None):
......@@ -434,10 +434,22 @@ class CSM(Op):
if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes:
raise TypeError("n_rows must be integer type", shape, shape.type)
static_shape = (None, None)
if (
shape.owner is not None
and isinstance(shape.owner.op, CSMProperties)
and shape.owner.outputs[3] is shape
):
static_shape = shape.owner.inputs[0].type.shape
return Apply(
self,
[data, indices, indptr, shape],
[SparseTensorType(dtype=data.type.dtype, format=self.format)()],
[
SparseTensorType(
dtype=data.type.dtype, format=self.format, shape=static_shape
)()
],
)
def perform(self, node, inputs, outputs):
......@@ -698,7 +710,7 @@ class DenseFromSparse(Op):
return Apply(
self,
[x],
[TensorType(dtype=x.type.dtype, shape=(None, None))()],
[TensorType(dtype=x.type.dtype, shape=x.type.shape)()],
)
def perform(self, node, inputs, outputs):
......
......@@ -127,6 +127,8 @@ class _sparse_py_operators:
def toarray(self):
return dense_from_sparse(self)
todense = toarray
@property
def shape(self):
# TODO: The plan is that the ShapeFeature in ptb.opt will do shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论