提交 83032266 authored 作者: Tanjay94's avatar Tanjay94

Added grad to GetItemList and its test.

上级 4934189c
...@@ -1407,7 +1407,7 @@ def norm(x,ord): ...@@ -1407,7 +1407,7 @@ def norm(x,ord):
elif ord == -1: elif ord == -1:
return tensor.min(tensor.sum(abs(x),0)) return tensor.min(tensor.sum(abs(x),0))
else: else:
raise ValueError(0) raise ValueError()
elif ndim > 2: elif ndim > 2:
raise NotImplementedError("We don't support norm witn ndim > 2") raise NotImplementedError("We don't support norm witn ndim > 2")
......
...@@ -18,7 +18,7 @@ from theano.gof.python25 import all ...@@ -18,7 +18,7 @@ from theano.gof.python25 import all
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.sparse.utils import hash_from_sparse from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented, grad_undefined
from theano.sparse.type import SparseType, _is_sparse from theano.sparse.type import SparseType, _is_sparse
from numpy.lib.stride_tricks import as_strided from numpy.lib.stride_tricks import as_strided
...@@ -1007,13 +1007,15 @@ class GetItemList(gof.op.Op): ...@@ -1007,13 +1007,15 @@ class GetItemList(gof.op.Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def infer_shape(self, node, shapes):
return [(shapes[1][0], shapes[0][1])]
def make_node(self, x, index): def make_node(self, x, index):
x = as_sparse_variable(x) x = as_sparse_variable(x)
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
ind = tensor.as_tensor_variable(index) ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1 assert ind.ndim == 1
assert 'int' in ind.dtype
return gof.Apply(self, [x, ind], [x.type()]) return gof.Apply(self, [x, ind], [x.type()])
...@@ -1023,13 +1025,63 @@ class GetItemList(gof.op.Op): ...@@ -1023,13 +1025,63 @@ class GetItemList(gof.op.Op):
assert _is_sparse(x) assert _is_sparse(x)
out[0] = x[indices] out[0] = x[indices]
def grad(self, inputs, g_outputs):
x, indices = inputs
gout, = g_outputs
return [GetItemListGrad(self)(x, indices, gout),
grad_undefined(self, 1, indices, "No gradient for this input")]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
get_item_list = GetItemList() get_item_list = GetItemList()
# Indexing class GetItemListGrad(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def infer_shape(self, node, shapes):
return [(shapes[0][0], shapes[0][1])]
def make_node(self, x, index, gz):
x = as_sparse_variable(x)
gz = as_sparse_variable(gz)
assert x.format in ["csr", "csc"]
assert gz.format in ["csr", "csc"]
ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1
return gof.Apply(self, [x, ind, gz], [x.type()])
def perform(self, node, inp, (out, )):
x = inp[0]
indices = inp[1]
gz = inp[2]
if x.format in ["csr"]:
y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1]))
else:
y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1]))
z = 0
for i in indices:
y[i] = gz[z]
z = z+1
out[0] = y
def __str__(self):
return self.__class__.__name__
get_item_list_grad = GetItemListGrad()
class GetItem2d(gof.op.Op): class GetItem2d(gof.op.Op):
# See doc in instance of this Op or function after this class definition. # See doc in instance of this Op or function after this class definition.
def __eq__(self, other): def __eq__(self, other):
......
...@@ -6,6 +6,7 @@ import numpy ...@@ -6,6 +6,7 @@ import numpy
try: try:
import scipy.sparse as sp import scipy.sparse as sp
import scipy.sparse import scipy.sparse
from scipy.sparse import csr_matrix
except ImportError: except ImportError:
pass # The variable enable_sparse will be used to disable the test file. pass # The variable enable_sparse will be used to disable the test file.
...@@ -31,7 +32,7 @@ from theano.sparse import ( ...@@ -31,7 +32,7 @@ from theano.sparse import (
AddSS, AddSD, MulSS, MulSD, Transpose, Neg, Remove0, AddSS, AddSD, MulSS, MulSD, Transpose, Neg, Remove0,
add, mul, structured_dot, transpose, add, mul, structured_dot, transpose,
csc_from_dense, csr_from_dense, dense_from_sparse, csc_from_dense, csr_from_dense, dense_from_sparse,
Dot, Usmm, sp_ones_like, GetItemScalar, Dot, Usmm, sp_ones_like, GetItemScalar, GetItemList,
SparseFromDense, SparseFromDense,
Cast, cast, HStack, VStack, AddSSData, add_s_s_data, Cast, cast, HStack, VStack, AddSSData, add_s_s_data,
structured_minimum, structured_maximum, structured_add, structured_minimum, structured_maximum, structured_add,
...@@ -2046,6 +2047,15 @@ class Test_getitem(unittest.TestCase): ...@@ -2046,6 +2047,15 @@ class Test_getitem(unittest.TestCase):
self.assertRaises(IndexError, f, A[0]) self.assertRaises(IndexError, f, A[0])
def test_get_item_list_grad(self):
op = theano.sparse.basic.GetItemList()
def op_with_fixed_index(x):
return op(x, index=numpy.asarray([0, 1]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64")
verify_grad_sparse(op_with_fixed_index, x_val)
def test_GetItem2D(self): def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论