提交 d82c903c authored 作者: Anatoly Belikov's avatar Anatoly Belikov

implementing gradient for Sort op for tensor3

上级 51368a6b
......@@ -80,11 +80,43 @@ class SortOp(theano.Op):
idx = argsort(*inputs, kind=self.kind, order=self.order)
# not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None])
pass
elif a.ndim == 3:
if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0].reshape(a.shape)[indices[0], indices[1], indices[2]]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined(
self, 1, axis,
"sort is not defined for non-integer axes so"
" sort(x, axis+eps) is undefined")
return [inp_grad, axis_grad]
def __get_argsort_indices(self, a, axis):
"""applies argsort to a along axis, returns indices which
can be used to sort original array"""
idx = argsort(a, axis, kind=self.kind, order=self.order)
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx
indices = []
if axis.data > 0:
axis_data = axis.data
else:
axis_data = a.ndim + axis.data
for i in range(a.ndim):
if i == axis_data:
indices.append(rev_idx)
else:
index_shape = [1] * a.ndim
index_shape[i] = a.shape[i]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
return indices
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
......
......@@ -82,6 +82,17 @@ class test_sort(unittest.TestCase):
utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self):
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
data = np.random.rand(2, 3, 2).astype(theano.config.floatX)
class TensorInferShapeTester(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论