提交 3b24a199 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2910 from noskill/master

implementing gradient for Sort op for tensor3
...@@ -61,7 +61,7 @@ class SortOp(theano.Op): ...@@ -61,7 +61,7 @@ class SortOp(theano.Op):
inp_grad = theano.gradient.grad_not_implemented( inp_grad = theano.gradient.grad_not_implemented(
self, 0, axis, self, 0, axis,
"Currently, we only implement the gradient on sort for vector" "Currently, we only implement the gradient on sort for vector"
" and matrix (and axis is None or 0)") " matrix (and axis is None or 0) and tensor3")
if a.ndim == 1: if a.ndim == 1:
idx = argsort(*inputs, kind=self.kind, order=self.order) idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1] # rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
...@@ -80,11 +80,51 @@ class SortOp(theano.Op): ...@@ -80,11 +80,51 @@ class SortOp(theano.Op):
idx = argsort(*inputs, kind=self.kind, order=self.order) idx = argsort(*inputs, kind=self.kind, order=self.order)
# not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None]) # not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None])
pass pass
elif a.ndim == 3:
if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][indices[0], indices[1], indices[2]]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined( axis_grad = theano.gradient.grad_undefined(
self, 1, axis, self, 1, axis,
"sort is not defined for non-integer axes so" "sort is not defined for non-integer axes so"
" sort(x, axis+eps) is undefined") " sort(x, axis+eps) is undefined")
return [inp_grad, axis_grad] return [inp_grad, axis_grad]
def __get_argsort_indices(self, a, axis):
    """Build the indices that undo sorting ``a`` along ``axis``.

    The result is meant to index the gradient taken with respect to
    ``sort(a, axis)`` so that it lines up with the unsorted input.

    Parameters
    ----------
    a
        Tensor being sorted (presumably a symbolic Theano variable;
        only ``ndim`` and ``shape`` are read here).
    axis
        Either ``None`` or an object exposing ``.data``. The callers
        visible in this file pass ``None`` or a ``theano.Constant``.

    Returns
    -------
    The inverse-permutation index itself when ``axis`` (or its constant
    value) is ``None``; otherwise a list of ``a.ndim`` index tensors,
    one per dimension of ``a``.
    """
    # Arg-sorting the arg-sort yields the inverse permutation.
    forward = argsort(a, axis, kind=self.kind, order=self.order)
    inverse = argsort(forward, axis, kind=self.kind, order=self.order)

    flattened = axis is None or (
        isinstance(axis, theano.Constant) and axis.data is None)
    if flattened:
        return inverse

    # Map a negative axis onto its non-negative equivalent.
    sorted_dim = axis.data if axis.data >= 0 else a.ndim + axis.data

    def _grid_index(dim):
        # One component of an open mesh, emulating
        # numpy.ogrid[0:a.shape[0], 0:a.shape[1], ...]: an arange laid
        # out along ``dim`` with length-1 axes everywhere else.
        grid_shape = [1] * a.ndim
        grid_shape[dim] = a.shape[dim]
        return theano.tensor.arange(a.shape[dim]).reshape(grid_shape)

    return [inverse if dim == sorted_dim else _grid_index(dim)
            for dim in range(a.ndim)]
""" """
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points. # R_op can receive None as eval_points.
......
...@@ -82,6 +82,24 @@ class test_sort(unittest.TestCase): ...@@ -82,6 +82,24 @@ class test_sort(unittest.TestCase):
utt.verify_grad(lambda x: sort(x, None), [data]) utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data]) #utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data]) #utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self):
    """Check the gradient of sort on a (2, 3, 4) input for axes -1..-3."""
    for neg_axis in (-1, -2, -3):
        # Fresh random input for every axis, as in the unrolled version.
        data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
        # The lambda is consumed inside this iteration, so closing over
        # the loop variable is safe here.
        utt.verify_grad(lambda x: sort(x, neg_axis), [data])
def test_grad_nonnegative_axis(self):
    """Check the gradient of sort on a (2, 3, 4) input for axes 0..2."""
    for pos_axis in (0, 1, 2):
        # Fresh random input for every axis, as in the unrolled version.
        data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
        # The lambda is consumed inside this iteration, so closing over
        # the loop variable is safe here.
        utt.verify_grad(lambda x: sort(x, pos_axis), [data])
class TensorInferShapeTester(utt.InferShapeTester): class TensorInferShapeTester(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论