Commit bc93de37 authored by Pascal Lamblin

Merge pull request #3168 from SinaHonari/issue3024

changing sort.grad to work with other ndim
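The patch below replaces the per-ndim special cases in `SortOp.grad` with a single construction: argsort the input along the sorted axis, argsort those indices again to obtain the inverse permutation, and use it (together with ogrid-style index arrays for the other dimensions) to route the output gradient back to the input positions. A minimal NumPy sketch of that idea, for illustration only (variable names below are mine, not the patch's):

    import numpy as np

    a = np.random.rand(2, 3, 4)
    axis = 1

    idx = np.argsort(a, axis=axis)        # permutation that sorts `a` along `axis`
    rev_idx = np.argsort(idx, axis=axis)  # argsort of the argsort = inverse permutation

    # ogrid-style index arrays for the remaining dimensions
    # (this is what SortOp.__get_expanded_dim emulates with Theano ops)
    indices = list(np.ogrid[tuple(slice(0, s) for s in a.shape)])
    indices[axis] = rev_idx

    # indexing the sorted array with these indices undoes the sort, so indexing
    # output_grads[0] the same way yields the gradient w.r.t. the input
    assert np.allclose(np.sort(a, axis=axis)[tuple(indices)], a)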
 import numpy as np
 import theano
-from theano.tensor import tensor
 from theano.tensor.basic import mul, arange
@@ -27,14 +24,8 @@ class SortOp(theano.Op):
     def make_node(self, input, axis=-1):
         input = theano.tensor.as_tensor_variable(input)
-        if (axis is None or
-                (isinstance(axis, theano.Constant) and axis.data is None)):
-            axis = theano.Constant(theano.gof.generic, None)
-            # axis=None flattens the array before sorting
-            out_type = tensor(dtype=input.dtype, broadcastable=[False])
-        else:
-            axis = theano.tensor.as_tensor_variable(axis)
-            out_type = input.type()
+        axis = theano.tensor.as_tensor_variable(axis)
+        out_type = input.type()
         return theano.Apply(self, [input, axis], [out_type])

     def perform(self, node, inputs, output_storage):
@@ -58,43 +49,22 @@ class SortOp(theano.Op):
     def grad(self, inputs, output_grads):
         a, axis = inputs
-        inp_grad = theano.gradient.grad_not_implemented(
-            self, 0, axis,
-            "Currently, we only implement the gradient on sort for vector"
-            " matrix (and axis is None or 0) and tensor3")
-        if a.ndim == 1:
-            idx = argsort(*inputs, kind=self.kind, order=self.order)
-            # rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
-            rev_idx = theano.tensor.eq(idx[None, :],
-                                       arange(a.shape[0])[:, None]).nonzero()[1]
-            inp_grad = output_grads[0][rev_idx]
-        elif a.ndim == 2:
-            if (axis is None or
-                    (isinstance(axis, theano.Constant) and axis.data is None)):
-                idx = argsort(*inputs, kind=self.kind, order=self.order)
-                rev_idx = theano.tensor.eq(
-                    idx[None, :],
-                    arange(a.shape[0] * a.shape[1])[:, None]).nonzero()[1]
-                inp_grad = output_grads[0][rev_idx].reshape(a.shape)
-            elif (axis == 0 or
-                    (isinstance(axis, theano.Constant) and axis.data == 0)):
-                idx = argsort(*inputs, kind=self.kind, order=self.order)
-                # not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None])
-                pass
-        elif a.ndim == 3:
-            if isinstance(axis, theano.Constant) and axis.data is not None:
-                indices = self.__get_argsort_indices(a, axis)
-                inp_grad = output_grads[0][indices[0], indices[1], indices[2]]
-            elif (axis is None or
-                    (isinstance(axis, theano.Constant) and axis.data is None)):
-                rev_idx = self.__get_argsort_indices(a, axis)
-                inp_grad = output_grads[0][rev_idx].reshape(a.shape)
+        indices = self.__get_argsort_indices(a, axis)
+        inp_grad = output_grads[0][tuple(indices)]
         axis_grad = theano.gradient.grad_undefined(
             self, 1, axis,
-            "sort is not defined for non-integer axes so"
-            " sort(x, axis+eps) is undefined")
+            "The gradient of sort is not defined "
+            "with respect to the integer axes itself")
         return [inp_grad, axis_grad]

+    def __get_expanded_dim(self, a, axis, i):
+        index_shape = [1] * a.ndim
+        index_shape[i] = a.shape[i]
+        # it's a way to emulate
+        # numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
+        index_val = arange(a.shape[i]).reshape(index_shape)
+        return index_val
+
     def __get_argsort_indices(self, a, axis):
         """Calculates indices which can be used to reverse
         sorting operation of "a" tensor along "axis"
@@ -109,22 +79,15 @@ class SortOp(theano.Op):
         idx = argsort(a, axis, kind=self.kind, order=self.order)
         # rev_idx is the reverse of previous argsort operation
         rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
-        if (axis is None or
-                (isinstance(axis, theano.Constant) and axis.data is None)):
-            return rev_idx
         indices = []
-        if axis.data >= 0:
-            axis_data = axis.data
-        else:
-            axis_data = a.ndim + axis.data
+        axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0),
+                                         axis.data, a.ndim + axis.data)
         for i in range(a.ndim):
-            if i == axis_data:
-                indices.append(rev_idx)
-            else:
-                index_shape = [1] * a.ndim
-                index_shape[i] = a.shape[i]
-                # it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
-                indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
+            index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data),
+                                             rev_idx,
+                                             self.__get_expanded_dim(a,
+                                                                     axis, i))
+            indices.append(index_val)
         return indices
     """
     def R_op(self, inputs, eval_points):
@@ -159,6 +122,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
     need to include all of the fields.

     """
+    if axis is None:
+        a = a.flatten()
+        axis = 0
     return SortOp(kind, order)(a, axis)
@@ -184,13 +150,8 @@ class ArgSortOp(theano.Op):
     def make_node(self, input, axis=-1):
         input = theano.tensor.as_tensor_variable(input)
-        if (axis is None or
-                (isinstance(axis, theano.Constant) and axis.data is None)):
-            axis = theano.Constant(theano.gof.generic, None)
-            bcast = [False]
-        else:
-            axis = theano.tensor.as_tensor_variable(axis)
-            bcast = input.type.broadcastable
+        axis = theano.tensor.as_tensor_variable(axis)
+        bcast = input.type.broadcastable
         return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
             dtype="int64", broadcastable=bcast)()])
@@ -245,4 +206,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
     the same shape as a that index data along the given axis in sorted
     order.
     """
+    if axis is None:
+        a = a.flatten()
+        axis = 0
     return ArgSortOp(kind, order)(a, axis)
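With `make_node` no longer special-casing `axis=None`, the `sort` and `argsort` wrappers above handle it by flattening the input and sorting along axis 0, matching NumPy's `axis=None` convention. A small usage sketch (assuming a standard Theano install; names are illustrative, not from the patch):

    import numpy as np
    import theano
    import theano.tensor as T
    from theano.tensor.sort import sort

    x = T.matrix('x')
    f = theano.function([x], sort(x, None))  # flattens, then sorts along axis 0

    a = np.random.rand(2, 3).astype(theano.config.floatX)
    assert np.allclose(f(a), np.sort(a, axis=None))

The remaining hunks update the sort tests to exercise the gradient for 2D, 3D and 4D inputs over both negative and non-negative axes.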
@@ -80,12 +80,17 @@ class test_sort(unittest.TestCase):
         data = np.random.rand(2, 3).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, None), [data])
-        #utt.verify_grad(lambda x: sort(x, 0), [data])
-        #utt.verify_grad(lambda x: sort(x, 1), [data])
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, None), [data])

     def test_grad_negative_axis(self):
+        # test 2D
+        data = np.random.rand(2, 3).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -1), [data])
+        data = np.random.rand(2, 3).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -2), [data])
+        # test 3D
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, -1), [data])
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
@@ -93,7 +98,24 @@ class test_sort(unittest.TestCase):
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, -3), [data])
+        # test 4D
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -1), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -2), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -3), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, -4), [data])

     def test_grad_nonnegative_axis(self):
+        # test 2D
+        data = np.random.rand(2, 3).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 0), [data])
+        data = np.random.rand(2, 3).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 1), [data])
+        # test 3D
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, 0), [data])
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
@@ -101,6 +123,15 @@ class test_sort(unittest.TestCase):
         data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
         utt.verify_grad(lambda x: sort(x, 2), [data])
+        # test 4D
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 0), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 1), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 2), [data])
+        data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
+        utt.verify_grad(lambda x: sort(x, 3), [data])

 class TensorInferShapeTester(utt.InferShapeTester):
     def test_sort(self):
...