提交 bc93de37 · 作者: Pascal Lamblin

Merge pull request #3168 from SinaHonari/issue3024

changing sort.grad to work with other ndim
import numpy as np
import theano
from theano.tensor import tensor
from theano.tensor.basic import mul, arange
......@@ -27,14 +24,8 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1):
    """Build the Apply node that sorts `input` along `axis`.

    Both arguments are converted with
    ``theano.tensor.as_tensor_variable`` and the output variable is
    created from the input's own type (``input.type()``).

    ``axis=None`` gets no special treatment here: the previous
    ``if``/``else`` that handled it was dead code, because both ``axis``
    and ``out_type`` were reassigned unconditionally right after it.
    The module-level ``sort`` wrapper flattens the input and passes
    ``axis=0`` for that case before the Op is applied.

    Returns
    -------
    theano.Apply
        Node with inputs ``[input, axis]`` and a single output.
    """
    input = theano.tensor.as_tensor_variable(input)
    axis = theano.tensor.as_tensor_variable(axis)
    out_type = input.type()
    return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
......@@ -58,43 +49,22 @@ class SortOp(theano.Op):
def grad(self, inputs, output_grads):
    """Return the gradients of the sort output w.r.t. `a` and `axis`.

    The gradient for `a` is the output gradient indexed with the index
    arrays produced by ``__get_argsort_indices`` (one entry per
    dimension, passed as a tuple).  This single path is used for every
    ``a.ndim``; the earlier per-``ndim`` branches only computed values
    that were overwritten by it.

    Parameters
    ----------
    inputs : sequence
        ``[a, axis]`` -- the tensor being sorted and the sort axis.
    output_grads : sequence
        Its first element is the gradient w.r.t. the sorted output.

    Returns
    -------
    list
        ``[inp_grad, axis_grad]``; ``axis_grad`` is a
        ``grad_undefined`` placeholder.
    """
    a, axis = inputs
    indices = self.__get_argsort_indices(a, axis)
    inp_grad = output_grads[0][tuple(indices)]
    # Only one message is passed: leaving the old and new wording side by
    # side made Python concatenate them into a single garbled sentence.
    axis_grad = theano.gradient.grad_undefined(
        self, 1, axis,
        "The gradient of sort is not defined "
        "with respect to the integer axes itself")
    return [inp_grad, axis_grad]
def __get_expanded_dim(self, a, axis, i):
    """Return ``arange(a.shape[i])`` reshaped to broadcast along dim `i`.

    The result has ``a.ndim`` dimensions, all of length 1 except
    dimension `i`, which has length ``a.shape[i]``.  Building one such
    array per dimension emulates
    ``numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]``.

    `axis` is accepted but not used by this helper.
    """
    # Start from an all-ones shape, then open up only dimension `i`.
    target_shape = [1] * a.ndim
    target_shape[i] = a.shape[i]
    return arange(a.shape[i]).reshape(target_shape)
def __get_argsort_indices(self, a, axis):
"""Calculates indices which can be used to reverse
sorting operation of "a" tensor along "axis"
......@@ -109,22 +79,15 @@ class SortOp(theano.Op):
idx = argsort(a, axis, kind=self.kind, order=self.order)
# rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx
indices = []
if axis.data >= 0:
axis_data = axis.data
else:
axis_data = a.ndim + axis.data
axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0),
axis.data, a.ndim + axis.data)
for i in range(a.ndim):
if i == axis_data:
indices.append(rev_idx)
else:
index_shape = [1] * a.ndim
index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data),
rev_idx,
self.__get_expanded_dim(a,
axis, i))
indices.append(index_val)
return indices
"""
def R_op(self, inputs, eval_points):
......@@ -159,6 +122,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
need to include all of the fields.
"""
if axis is None:
a = a.flatten()
axis = 0
return SortOp(kind, order)(a, axis)
......@@ -184,13 +150,8 @@ class ArgSortOp(theano.Op):
def make_node(self, input, axis=-1):
    """Build the Apply node that computes argsort of `input` along `axis`.

    Both arguments are converted with
    ``theano.tensor.as_tensor_variable``.  The output is an ``int64``
    tensor that reuses the input's broadcastable pattern.

    ``axis=None`` gets no special treatment here: the previous
    ``if``/``else`` that handled it was dead code, because both ``axis``
    and ``bcast`` were reassigned unconditionally right after it.  The
    module-level ``argsort`` wrapper flattens the input and passes
    ``axis=0`` for that case before the Op is applied.

    Returns
    -------
    theano.Apply
        Node with inputs ``[input, axis]`` and a single int64 output.
    """
    input = theano.tensor.as_tensor_variable(input)
    axis = theano.tensor.as_tensor_variable(axis)
    bcast = input.type.broadcastable
    return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
        dtype="int64", broadcastable=bcast)()])
......@@ -245,4 +206,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
the same shape as a that index data along the given axis in sorted
order.
"""
if axis is None:
a = a.flatten()
axis = 0
return ArgSortOp(kind, order)(a, axis)
......@@ -80,12 +80,17 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self):
# test 2D
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
# test 3D
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -93,7 +98,24 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
# test 4D
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -4), [data])
def test_grad_nonnegative_axis(self):
# test 2D
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
# test 3D
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -101,6 +123,15 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
# test 4D
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 3), [data])
class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论