提交 3a591778 authored 作者: Tanjay94's avatar Tanjay94

Fixed rebasing error.

上级 8331e3aa
...@@ -346,11 +346,10 @@ class _sparse_py_operators: ...@@ -346,11 +346,10 @@ class _sparse_py_operators:
ret = get_item_scalar(self, args) ret = get_item_scalar(self, args)
else: else:
ret = get_item_2d(self, args) ret = get_item_2d(self, args)
elif isinstance(args[0], list):
ret = get_item_list(self, args[0])
else: else:
if isinstance(args[0], list): ret = get_item_2d(self, args)
ret = get_item_list(self, args)
else:
ret = get_item_2d(self, args)
return ret return ret
...@@ -995,9 +994,8 @@ class SparseFromDense(gof.op.Op): ...@@ -995,9 +994,8 @@ class SparseFromDense(gof.op.Op):
csr_from_dense = SparseFromDense('csr') csr_from_dense = SparseFromDense('csr')
"""Convert a dense matrix to a sparse csr matrix. """Convert a dense matrix to a sparse csr matrix.
<<<<<<< HEAD """
:param x: A dense matrix.
=======
# Indexing # Indexing
class GetItemList(gof.op.Op): class GetItemList(gof.op.Op):
...@@ -1012,19 +1010,16 @@ class GetItemList(gof.op.Op): ...@@ -1012,19 +1010,16 @@ class GetItemList(gof.op.Op):
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
ind = tensor.as_tensor_variable(index) ind = tensor.as_tensor_variable(index)
assert ind.ndim == 2 assert ind.ndim == 1
assert 'int' in ind.dtype assert 'int' in ind.dtype
return gof.Apply(self, [x, ind], [x.type()]) return gof.Apply(self, [x, ind], [x.type()])
def perform(self, node, inp, (out, )): def perform(self, node, inp, (out, )):
x = inp[0] x = inp[0]
y = []
indices = inp[1] indices = inp[1]
assert _is_sparse(x) assert _is_sparse(x)
out[0] = x[indices]
for ind in indices:
out[0] = x[indices[0][ind]]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -2021,34 +2021,30 @@ class Test_getitem(unittest.TestCase): ...@@ -2021,34 +2021,30 @@ class Test_getitem(unittest.TestCase):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_GetItemList(self): def test_GetItemList(self):
rng = numpy.random.RandomState(utt.fetch_seed())
a = sparse.csr_matrix()
b = sparse.csc_matrix()
y = a[[0, 1, 2, 3, 1]]
z = b[[0, 1, 2, 3, 1]]
fa = theano.function([a], y) a, A = sparse_random_inputs('csr', (4, 5))
fb = theano.function([b], z) b, B = sparse_random_inputs('csc', (4, 5))
y = a[0][[0, 1, 2, 3, 1]]
z = b[0][[0, 1, 2, 3, 1]]
A = rng.rand(4, 4).astype(theano.config.floatX) fa = theano.function([a[0]], y)
fb = theano.function([b[0]], z)
t_geta = fa(numpy.asarray(A, dtype="float64")).todense() t_geta = fa(A[0]).todense()
t_getb = fb(numpy.asarray(A, dtype="float64")).todense() t_getb = fb(B[0]).todense()
s_geta = numpy.asarray(scipy.sparse.csr_matrix(A)[[0, 1, 2, 3, 1]].todense(), dtype="float64") s_geta = scipy.sparse.csr_matrix(A[0])[[0, 1, 2, 3, 1]].todense()
s_getb = numpy.asarray(scipy.sparse.csc_matrix(A)[[0, 1, 2, 3, 1]].todense(), dtype="float64") s_getb = scipy.sparse.csc_matrix(B[0])[[0, 1, 2, 3, 1]].todense()
utt.assert_allclose(t_geta, s_geta) utt.assert_allclose(t_geta, s_geta)
utt.assert_allclose(t_getb, s_getb) utt.assert_allclose(t_getb, s_getb)
def test_GetItemList_wrong_index(self): def test_GetItemList_wrong_index(self):
rng = numpy.random.RandomState(utt.fetch_seed()) a, A = sparse_random_inputs('csr', (4, 5))
x = sparse.csr_matrix() y = a[0][[0, 4]]
y = x[[0, 4]] f = theano.function([a[0]], y)
f = theano.function([x], y)
A = rng.rand(2, 2).astype(theano.config.floatX) self.assertRaises(IndexError, f, A[0])
self.assertRaises(IndexError, f, A)
def test_GetItem2D(self): def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论