提交 d1fffa71 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3498 from harmdevries89/grad_argsort

Grad argsort
...@@ -173,11 +173,7 @@ class ArgSortOp(theano.Op): ...@@ -173,11 +173,7 @@ class ArgSortOp(theano.Op):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# No grad defined for integers. # No grad defined for integers.
inp, axis = inputs inp, axis = inputs
inp_grad = theano.gradient.grad_not_implemented( inp_grad = inp.zeros_like()
self, 0, axis,
"I'm not sure if argsort should have its gradient"
" implemented or is should be marked as undefined."
" So I mark it as not implemented for now.")
axis_grad = theano.gradient.grad_undefined( axis_grad = theano.gradient.grad_undefined(
self, 1, axis, self, 1, axis,
"argsort is not defined for non-integer axes so" "argsort is not defined for non-integer axes so"
......
...@@ -133,6 +133,7 @@ class test_sort(unittest.TestCase): ...@@ -133,6 +133,7 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 3), [data]) utt.verify_grad(lambda x: sort(x, 3), [data])
class TensorInferShapeTester(utt.InferShapeTester): class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self): def test_sort(self):
x = tensor.matrix() x = tensor.matrix()
...@@ -209,3 +210,13 @@ def test_argsort(): ...@@ -209,3 +210,13 @@ def test_argsort():
assert np.allclose(gv, gt) assert np.allclose(gv, gt)
def test_argsort_grad():
    # Numerically verify the gradient of argsort on random inputs, one
    # (input shape, sort axis) pair per case: a negative last axis, a
    # negative interior axis, and a positive last axis.
    cases = (
        ((2, 3), -1),
        ((2, 3, 4, 5), -3),
        ((2, 3, 3), 2),
    )
    for shape, sort_axis in cases:
        sample = np.random.rand(*shape).astype(theano.config.floatX)
        # Bind the axis as a default so the lambda does not depend on the
        # loop variable after this iteration.
        utt.verify_grad(
            lambda x, axis=sort_axis: argsort(x, axis=axis), [sample])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论