提交 22c1b380 authored 作者: Anatoly Belikov's avatar Anatoly Belikov

fix grad for axis=0, add tests

上级 d82c903c
...@@ -61,7 +61,7 @@ class SortOp(theano.Op): ...@@ -61,7 +61,7 @@ class SortOp(theano.Op):
inp_grad = theano.gradient.grad_not_implemented( inp_grad = theano.gradient.grad_not_implemented(
self, 0, axis, self, 0, axis,
"Currently, we only implement the gradient on sort for vector" "Currently, we only implement the gradient on sort for vector"
" and matrix (and axis is None or 0)") " matrix (and axis is None or 0) and tensor3")
if a.ndim == 1: if a.ndim == 1:
idx = argsort(*inputs, kind=self.kind, order=self.order) idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1] # rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
...@@ -95,17 +95,24 @@ class SortOp(theano.Op): ...@@ -95,17 +95,24 @@ class SortOp(theano.Op):
return [inp_grad, axis_grad] return [inp_grad, axis_grad]
def __get_argsort_indices(self, a, axis): def __get_argsort_indices(self, a, axis):
"""applies argsort to a along axis, returns indices which """Calculates indices which can be used to reverse
can be used to sort original array""" sorting operation of "a" tensor along "axis"
returns:
1d array if axis is None
list of length len(a.shape) otherwise
"""
# The goal is to get gradient wrt input from gradient
# wrt sort(input, axis)
idx = argsort(a, axis, kind=self.kind, order=self.order) idx = argsort(a, axis, kind=self.kind, order=self.order)
# rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order) rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx return rev_idx
indices = [] indices = []
if axis.data > 0: if axis.data >= 0:
axis_data = axis.data axis_data = axis.data
else: else:
axis_data = a.ndim + axis.data axis_data = a.ndim + axis.data
...@@ -115,6 +122,7 @@ class SortOp(theano.Op): ...@@ -115,6 +122,7 @@ class SortOp(theano.Op):
else: else:
index_shape = [1] * a.ndim index_shape = [1] * a.ndim
index_shape[i] = a.shape[i] index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape)) indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
return indices return indices
""" """
......
...@@ -82,17 +82,24 @@ class test_sort(unittest.TestCase): ...@@ -82,17 +82,24 @@ class test_sort(unittest.TestCase):
utt.verify_grad(lambda x: sort(x, None), [data]) utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data]) #utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data]) #utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data]) utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self): def test_grad_negative_axis(self):
data = np.random.rand(2, 3, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data]) utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data]) utt.verify_grad(lambda x: sort(x, -2), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
def test_grad_nonnegative_axis(self):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
class TensorInferShapeTester(utt.InferShapeTester): class TensorInferShapeTester(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论