提交 7ffa8835 authored 作者: Frederic's avatar Frederic

Add sort.grad support for vector and matrix when axis is None.

上级 d82eb54a
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import theano import theano
from theano.tensor import tensor from theano.tensor import tensor
from theano.tensor.basic import mul from theano.tensor.basic import mul, arange
class SortOp(theano.Op): class SortOp(theano.Op):
...@@ -27,7 +27,8 @@ class SortOp(theano.Op): ...@@ -27,7 +27,8 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if axis is None: if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None) axis = theano.Constant(theano.gof.generic, None)
# axis=None flattens the array before sorting # axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False]) out_type = tensor(dtype=input.dtype, broadcastable=[False])
...@@ -55,8 +56,34 @@ class SortOp(theano.Op): ...@@ -55,8 +56,34 @@ class SortOp(theano.Op):
assert inputs_shapes[1] == () assert inputs_shapes[1] == ()
return [inputs_shapes[0]] return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now. def grad(self, inputs, output_grads):
#def grad(self, inputs, output_grads): a, axis = inputs
inp_grad = theano.gradient.grad_not_implemented(
self, 0, axis,
"Currently, we only implement the gradient on sort for vector")
if a.ndim == 1:
idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx]
elif a.ndim == 2:
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
idx = argsort(*inputs, kind=self.kind, order=self.order)
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0]*a.shape[1])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
elif (axis == 0 or
(isinstance(axis, theano.Constant) and axis.data == 0)):
idx = argsort(*inputs, kind=self.kind, order=self.order)
#not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None])
pass
axis_grad = theano.gradient.grad_undefined(
self, 1, axis,
"sort is not defined for non-integer axes so"
" sort(x, axis+eps) is undefined")
return [inp_grad, axis_grad]
""" """
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points. # R_op can receive None as eval_points.
...@@ -115,7 +142,8 @@ class ArgSortOp(theano.Op): ...@@ -115,7 +142,8 @@ class ArgSortOp(theano.Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if axis is None: if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None) axis = theano.Constant(theano.gof.generic, None)
bcast = [False] bcast = [False]
else: else:
......
...@@ -68,6 +68,33 @@ class test_sort(unittest.TestCase): ...@@ -68,6 +68,33 @@ class test_sort(unittest.TestCase):
gt = np.sort(self.m_val, None) gt = np.sort(self.m_val, None)
assert np.allclose(gv, gt) assert np.allclose(gv, gt)
def test_grad_vector(self):
a = theano.tensor.vector()
#cost = np.power(sort(a), 2).sum()
#g = theano.tensor.grad(cost, a)
#f = theano.function([a], g)
data = np.asarray([7., 10., 2.], dtype=theano.config.floatX)
data = np.random.rand(10).astype(theano.config.floatX)
#assert (f(data) == [20., 4., 14.]).all()
utt.verify_grad(sort, [data])
def test_grad_none_axis(self):
#a = theano.tensor.vector()
#cost = np.power(sort(a, None), 2).sum()
#g = theano.tensor.grad(cost, a)
#f = theano.function([a], g)
data = np.asarray([7., 10., 2.], dtype=theano.config.floatX)
data = np.random.rand(10).astype(theano.config.floatX)
#assert (f(data) == [20., 4., 14.]).all()
utt.verify_grad(lambda x: sort(x, None), [data])
utt.verify_grad(lambda x: sort(x, 0), [data])
#a = theano.tensor.matrix()
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data])
class TensorInferShapeTester(utt.InferShapeTester): class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self): def test_sort(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论