提交 18d155df authored 作者: Sina Honari's avatar Sina Honari

changing sort to work with other ndim

上级 3f3bf149
......@@ -68,23 +68,10 @@ class SortOp(theano.Op):
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx]
elif a.ndim == 2:
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
idx = argsort(*inputs, kind=self.kind, order=self.order)
rev_idx = theano.tensor.eq(
idx[None, :],
arange(a.shape[0] * a.shape[1])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
elif (axis == 0 or
(isinstance(axis, theano.Constant) and axis.data == 0)):
idx = argsort(*inputs, kind=self.kind, order=self.order)
# not working: numpy.where(idx[None, :]==numpy.arange(2)[:, None, None])
pass
elif a.ndim == 3:
elif isinstance(axis, theano.Constant):
if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][indices[0], indices[1], indices[2]]
inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
......
......@@ -80,12 +80,17 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
#utt.verify_grad(lambda x: sort(x, 0), [data])
#utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self):
# test 2D
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
# test 3D
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -93,7 +98,24 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
# test 4D
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -4), [data])
def test_grad_nonnegative_axis(self):
# test 2D
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
# test 3D
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -101,6 +123,15 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
# test 4D
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 3), [data])
class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论