提交 8331e3aa authored 作者: Tanjay94's avatar Tanjay94

Added Fancy Indexing for sparse matrix.

上级 c4266e45
...@@ -347,7 +347,10 @@ class _sparse_py_operators: ...@@ -347,7 +347,10 @@ class _sparse_py_operators:
else: else:
ret = get_item_2d(self, args) ret = get_item_2d(self, args)
else: else:
ret = get_item_2d(self, args) if isinstance(args[0], list):
ret = get_item_list(self, args)
else:
ret = get_item_2d(self, args)
return ret return ret
...@@ -992,7 +995,47 @@ class SparseFromDense(gof.op.Op): ...@@ -992,7 +995,47 @@ class SparseFromDense(gof.op.Op):
csr_from_dense = SparseFromDense('csr') csr_from_dense = SparseFromDense('csr')
"""Convert a dense matrix to a sparse csr matrix. """Convert a dense matrix to a sparse csr matrix.
<<<<<<< HEAD
:param x: A dense matrix. :param x: A dense matrix.
=======
# Indexing
class GetItemList(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, index):
x = as_sparse_variable(x)
assert x.format in ["csr", "csc"]
ind = tensor.as_tensor_variable(index)
assert ind.ndim == 2
assert 'int' in ind.dtype
return gof.Apply(self, [x, ind], [x.type()])
def perform(self, node, inp, (out, )):
x = inp[0]
y = []
indices = inp[1]
assert _is_sparse(x)
for ind in indices:
out[0] = x[indices[0][ind]]
def __str__(self):
return self.__class__.__name__
get_item_list = GetItemList()
class GetItem2d(gof.op.Op):
"""Implement a subtensor of sparse variable and that return a
sparse matrix.
>>>>>>> Added Fancy Indexing for sparse matrix.
:return: The same as `x` in a sparse matrix format. :return: The same as `x` in a sparse matrix format.
......
...@@ -15,7 +15,7 @@ from theano import sparse ...@@ -15,7 +15,7 @@ from theano import sparse
from theano import compile, config, gof from theano import compile, config, gof
from theano.sparse import enable_sparse from theano.sparse import enable_sparse
from theano.gof.python25 import all, any, product from theano.gof.python25 import all, any, product
from theano.tensor.basic import _allclose
if not enable_sparse: if not enable_sparse:
raise SkipTest('Optional package sparse disabled') raise SkipTest('Optional package sparse disabled')
...@@ -2020,6 +2020,36 @@ class Test_getitem(unittest.TestCase): ...@@ -2020,6 +2020,36 @@ class Test_getitem(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_GetItemList(self):
rng = numpy.random.RandomState(utt.fetch_seed())
a = sparse.csr_matrix()
b = sparse.csc_matrix()
y = a[[0, 1, 2, 3, 1]]
z = b[[0, 1, 2, 3, 1]]
fa = theano.function([a], y)
fb = theano.function([b], z)
A = rng.rand(4, 4).astype(theano.config.floatX)
t_geta = fa(numpy.asarray(A, dtype="float64")).todense()
t_getb = fb(numpy.asarray(A, dtype="float64")).todense()
s_geta = numpy.asarray(scipy.sparse.csr_matrix(A)[[0, 1, 2, 3, 1]].todense(), dtype="float64")
s_getb = numpy.asarray(scipy.sparse.csc_matrix(A)[[0, 1, 2, 3, 1]].todense(), dtype="float64")
utt.assert_allclose(t_geta, s_geta)
utt.assert_allclose(t_getb, s_getb)
def test_GetItemList_wrong_index(self):
rng = numpy.random.RandomState(utt.fetch_seed())
x = sparse.csr_matrix()
y = x[[0, 4]]
f = theano.function([x], y)
A = rng.rand(2, 2).astype(theano.config.floatX)
self.assertRaises(IndexError, f, A)
def test_GetItem2D(self): def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论