提交 e72884a4 authored 作者: Frederic's avatar Frederic

Make the apply node of ConstructSparseFromList take the shape as input.

This allow to remove the value itself from the graph earlier.
上级 41d61600
......@@ -3317,6 +3317,18 @@ class ConstructSparseFromList(gof.Op):
return self.__class__.__name__
def make_node(self, x, values, ilist):
"""
:param x: a matrix that specify the output shape.
:param values: a matrix with the values that we want in the output.
:param ilist: a vector with the same lenght as the number of rows
then values. It specify where in the output to put
the corresponding rows.
This create a sparse matrix with the same shape as `x`. Its
values are the are the rows of `values` moved to the `ilist`
corresponding rows.
"""
x_ = theano.tensor.as_tensor_variable(x)
values_ = theano.tensor.as_tensor_variable(values)
ilist_ = theano.tensor.as_tensor_variable(ilist)
......@@ -3325,23 +3337,19 @@ class ConstructSparseFromList(gof.Op):
raise TypeError('index must be integers')
if ilist_.type.ndim != 1:
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
if values_.type.ndim > x_.type.ndim:
raise TypeError('cannot construct sparse matrix as dimensions differ')
return gof.Apply(self, [x_, values_, ilist_],
if x_.type.ndim != 2:
raise TypeError(
'cannot create a sparse matrix with %d dimensions' %
x_.type.ndim)
if values_.type.ndim != 2:
raise TypeError(
'cannot create a sparse matrix from values with %d ndim' %
values_.type.ndim)
return gof.Apply(self, [x_.shape, values_, ilist_],
[csc_matrix(dtype=x.dtype)])
def perform(self, node, inp, out_):
"""
:param inp: tuple(x, values, ilist)
x: specify the output shape and dtype only.
values: a matrix with the values that we want in the output
ilist: a vector with the same lenght as the number of row
then values. It specify where in the output to put
the corresponding rows.
"""
x, values, ilist = inp
out_shape, values, ilist = inp
out, = out_
rows, cols = values.shape
assert rows == len(ilist)
......@@ -3351,8 +3359,8 @@ class ConstructSparseFromList(gof.Op):
shape=(cols, ilist.shape[0])).flatten()
data = values.T.flatten()
out[0] = scipy.sparse.csc_matrix((data, indices, indptr),
shape=x.shape,
dtype=x.dtype)
shape=out_shape,
dtype=values.dtype)
def infer_shape(self, node, ishapes):
x, y, ilist = ishapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论