提交 6a485e55 authored 作者: Frederic's avatar Frederic

Add doc about ConstructSparseFromList and clean up the code a little.

上级 82368140
...@@ -3316,9 +3316,9 @@ class ConstructSparseFromList(gof.Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x, y, ilist): def make_node(self, x, values, ilist):
x_ = theano.tensor.as_tensor_variable(x) x_ = theano.tensor.as_tensor_variable(x)
y_ = theano.tensor.as_tensor_variable(y) values_ = theano.tensor.as_tensor_variable(values)
ilist_ = theano.tensor.as_tensor_variable(ilist) ilist_ = theano.tensor.as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
...@@ -3327,21 +3327,31 @@ class ConstructSparseFromList(gof.Op):
raise TypeError('index must be vector') raise TypeError('index must be vector')
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
if y_.type.ndim > x_.type.ndim: if values_.type.ndim > x_.type.ndim:
raise TypeError('cannot construct sparse matrix as dimensions differ') raise TypeError('cannot construct sparse matrix as dimensions differ')
return gof.Apply(self, [x_, y_, ilist_], [theano.sparse.csc_matrix(dtype=x.dtype)]) return gof.Apply(self, [x_, values_, ilist_],
[csc_matrix(dtype=x.dtype)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, values, idx = inp """
:param inp: tuple(x, values, ilist)
x: specify the output shape and dtype only.
values: a matrix with the values that we want in the output
ilist: a vector with the same lenght as the number of row
then values. It specify where in the output to put
the corresponding rows.
"""
x, values, ilist = inp
out, = out_ out, = out_
rows, cols = values.shape rows, cols = values.shape
assert rows == len(idx) assert rows == len(ilist)
indptr = numpy.arange(cols + 1) * rows indptr = numpy.arange(cols + 1) * rows
indices = as_strided(idx, indices = as_strided(ilist,
strides=(0, idx.strides[0]), strides=(0, ilist.strides[0]),
shape = (cols, idx.shape[0])).flatten() shape=(cols, ilist.shape[0])).flatten()
data = values.T.flatten() data = values.T.flatten()
out[0] = scipy.sparse.csc_matrix((data, indices, indptr), shape=x.shape, out[0] = scipy.sparse.csc_matrix((data, indices, indptr),
shape=x.shape,
dtype=x.dtype) dtype=x.dtype)
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
...@@ -3368,3 +3378,5 @@ class ConstructSparseFromList(gof.Op):
gy = theano.tensor.advanced_subtensor1(g_output, *idx_list) gy = theano.tensor.advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
# Module-level singleton instance of the Op.
# NOTE(review): the constructor appears to take no arguments, so one shared
# instance presumably serves all callers (the usual Theano convention for
# parameterless Ops) -- confirm the class defines no mutable state.
construct_sparse_from_list = ConstructSparseFromList()
...@@ -6929,7 +6929,6 @@ class AdvancedSubtensor1(Op):
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [[True]] rval = [[True]]
for ipt in node.inputs[1:]: for ipt in node.inputs[1:]:
...@@ -6939,17 +6938,18 @@ class AdvancedSubtensor1(Op):
def grad(self, inputs, grads):
    """Gradient of advanced indexing wrt ``(x, ilist)``.

    The index vector ``ilist`` is integer-valued and therefore
    disconnected from the cost; only ``x`` receives a gradient.
    """
    global sparse_module_ref
    x, ilist = inputs
    gz, = grads
    assert len(inputs) == 2
    if x.type.sparse_grad:
        if sparse_module_ref is None:
            # Lazy import, cached in the module-level global -- presumably
            # to avoid a circular import between theano.tensor and
            # theano.sparse (TODO confirm).
            import theano.sparse as sparse_module_ref
        # Represent the gradient sparsely: only the indexed rows are
        # non-zero, so build a CSC matrix holding gz at those rows.
        rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz,
                                                              ilist)]
    else:
        # Dense gradient: scatter-add gz into a zero tensor shaped like x.
        rval1 = [advanced_inc_subtensor1(zeros_like(x), gz, ilist)]
    # One DisconnectedType per non-differentiable (index) input.
    return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论