提交 e493985e authored 作者: abergeron's avatar abergeron

Merge pull request #1923 from nouiz/sort_grad

Add sort.grad support for vector and matrix when axis is None.
......@@ -3,7 +3,7 @@ import numpy as np
import theano
from theano.tensor import tensor
from theano.tensor.basic import mul
from theano.tensor.basic import mul, arange
class SortOp(theano.Op):
......@@ -27,7 +27,8 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
if axis is None:
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
......@@ -55,8 +56,35 @@ class SortOp(theano.Op):
assert inputs_shapes[1] == ()
return [inputs_shapes[0]]
#**** It needs the argsort, so we can't do it now.
#def grad(self, inputs, output_grads):
def grad(self, inputs, output_grads):
    """Build the symbolic gradients of the sort with respect to its inputs.

    ``inputs`` is unpacked as ``(a, axis)`` and only ``output_grads[0]``
    is used.

    The gradient with respect to ``a`` is implemented for two cases:

    * ``a`` is a vector;
    * ``a`` is a matrix and ``axis`` is None (flattened sort).

    Every other case yields a "gradient not implemented" placeholder.
    The gradient with respect to ``axis`` is always marked undefined.

    Returns the list ``[inp_grad, axis_grad]``.
    """
    a, axis = inputs

    # ``axis`` may be the Python value None or a Constant wrapping None.
    axis_is_none = (axis is None or
                    (isinstance(axis, theano.Constant) and
                     axis.data is None))

    # Default result for the unsupported cases. The placeholder is
    # attached to ``a`` because it stands for the gradient of input 0.
    inp_grad = theano.gradient.grad_not_implemented(
        self, 0, a,
        "Currently, we only implement the gradient on sort for vector"
        " and matrix (and axis is None)")

    if a.ndim == 1 or (a.ndim == 2 and axis_is_none):
        # In both supported cases the sorted output is 1-D and satisfies
        # out[k] == a.flatten()[idx[k]], so the gradient of element j is
        # output_grads[0][k] for the k with idx[k] == j.  That k is the
        # inverse permutation of idx, which is exactly argsort(idx); this
        # avoids building an n x n comparison matrix to locate it.
        idx = argsort(*inputs, kind=self.kind, order=self.order)
        rev_idx = argsort(idx)
        inp_grad = output_grads[0][rev_idx]
        if a.ndim == 2:
            # Restore the matrix shape lost by the flattened sort.
            inp_grad = inp_grad.reshape(a.shape)
    # NOTE: sorting a matrix along an explicit axis (0 or 1) is not
    # handled yet and keeps the "not implemented" placeholder above.

    axis_grad = theano.gradient.grad_undefined(
        self, 1, axis,
        "sort is not defined for non-integer axes so"
        " sort(x, axis+eps) is undefined")
    return [inp_grad, axis_grad]
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
......@@ -115,7 +143,8 @@ class ArgSortOp(theano.Op):
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
if axis is None:
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None)
bcast = [False]
else:
......
......@@ -68,6 +68,21 @@ class test_sort(unittest.TestCase):
gt = np.sort(self.m_val, None)
assert np.allclose(gv, gt)
def test_grad_vector(self):
    """Check the gradient of ``sort`` on a random 10-element vector."""
    data = np.random.rand(10).astype(theano.config.floatX)
    utt.verify_grad(sort, [data])
def test_grad_none_axis(self):
    """Check the gradient of ``sort`` for axis=None and axis=0.

    A vector is checked with both axis values; a 2x3 matrix is checked
    with axis=None only.
    """
    def sort_along(axis):
        # Bind the axis so verify_grad only sees the tensor argument.
        return lambda x: sort(x, axis)

    vector = np.random.rand(10).astype(theano.config.floatX)
    for axis in (None, 0):
        utt.verify_grad(sort_along(axis), [vector])

    matrix = np.random.rand(2, 3).astype(theano.config.floatX)
    utt.verify_grad(sort_along(None), [matrix])
    # The matrix checks along an explicit axis (0 and 1) are left out
    # here; they were disabled in the original test.
class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论