提交 519f9d26 authored 作者: Tanjay94's avatar Tanjay94

Added GetItem2Lists function, its grad and all its tests.

上级 bf4954f9
...@@ -1099,16 +1099,14 @@ class GetItem2Lists(gof.op.Op): ...@@ -1099,16 +1099,14 @@ class GetItem2Lists(gof.op.Op):
ind1 = tensor.as_tensor_variable(ind1) ind1 = tensor.as_tensor_variable(ind1)
ind2 = tensor.as_tensor_variable(ind2) ind2 = tensor.as_tensor_variable(ind2)
return gof.Apply(self, [x, ind1, ind2], [theano.tensor.vector()]) return gof.Apply(self, [x, ind1, ind2],
[theano.tensor.vector()])
def perform(self, node, inp, (out, )): def perform(self, node, inp, (out, )):
x = inp[0] x = inp[0]
ind1 = inp[1] ind1 = inp[1]
ind2 = inp[2] ind2 = inp[2]
p = [] out[0] = numpy.asarray(x[ind1, ind2]).flatten()
for i in ind1:
p.append(x[(ind1[i],ind2[i])])
out[0] = numpy.asarray(p)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
x, ind1, ind2 = inputs x, ind1, ind2 = inputs
...@@ -1136,13 +1134,11 @@ class GetItem2ListsGrad(gof.op.Op): ...@@ -1136,13 +1134,11 @@ class GetItem2ListsGrad(gof.op.Op):
def make_node(self, x, ind1, ind2, gz): def make_node(self, x, ind1, ind2, gz):
x = as_sparse_variable(x) x = as_sparse_variable(x)
gz = as_sparse_variable(gz)
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
assert gz.format in ["csr", "csc"]
ind1 = tensor.as_tensor_variable(index) ind1 = tensor.as_tensor_variable(ind1)
ind2 = tensor.as_tensor_variable(index) ind2 = tensor.as_tensor_variable(ind2)
assert ind1.ndim == 1 assert ind1.ndim == 1
assert ind2.ndim == 1 assert ind2.ndim == 1
......
...@@ -32,7 +32,7 @@ from theano.sparse import ( ...@@ -32,7 +32,7 @@ from theano.sparse import (
AddSS, AddSD, MulSS, MulSD, Transpose, Neg, Remove0, AddSS, AddSD, MulSS, MulSD, Transpose, Neg, Remove0,
add, mul, structured_dot, transpose, add, mul, structured_dot, transpose,
csc_from_dense, csr_from_dense, dense_from_sparse, csc_from_dense, csr_from_dense, dense_from_sparse,
Dot, Usmm, sp_ones_like, GetItemScalar, GetItemList, Dot, Usmm, sp_ones_like, GetItemScalar, GetItemList, GetItem2Lists,
SparseFromDense, SparseFromDense,
Cast, cast, HStack, VStack, AddSSData, add_s_s_data, Cast, cast, HStack, VStack, AddSSData, add_s_s_data,
structured_minimum, structured_maximum, structured_add, structured_minimum, structured_maximum, structured_add,
...@@ -2056,6 +2056,41 @@ class Test_getitem(unittest.TestCase): ...@@ -2056,6 +2056,41 @@ class Test_getitem(unittest.TestCase):
verify_grad_sparse(op_with_fixed_index, x_val) verify_grad_sparse(op_with_fixed_index, x_val)
def test_GetItem2Lists(self):
a, A = sparse_random_inputs('csr', (4, 5))
b, B = sparse_random_inputs('csc', (4, 5))
y = a[0][[0, 0, 1, 3], [0, 1, 2, 4]]
z = b[0][[0, 0, 1, 3], [0, 1, 2, 4]]
fa = theano.function([a[0]], y)
fb = theano.function([b[0]], z)
t_geta = fa(A[0])
t_getb = fb(B[0])
s_geta = numpy.asarray(scipy.sparse.csr_matrix(A[0])[[0, 0, 1, 3], [0, 1, 2, 4]])
s_getb = numpy.asarray(scipy.sparse.csc_matrix(B[0])[[0, 0, 1, 3], [0, 1, 2, 4]])
utt.assert_allclose(t_geta, s_geta)
utt.assert_allclose(t_getb, s_getb)
def test_GetItem2Lists_wrong_index(self):
a, A = sparse_random_inputs('csr', (4, 5))
y = a[0][[0, 4], [0, 4]]
f = theano.function([a[0]], y)
self.assertRaises(IndexError, f, A[0])
def test_get_item_2lists_grad(self):
op = theano.sparse.basic.GetItem2Lists()
def op_with_fixed_index(x):
return op(x, ind1=numpy.asarray([0, 1]), ind2=numpy.asarray([1, 1]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64")
verify_grad_sparse(op_with_fixed_index, x_val)
def test_GetItem2D(self): def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论