提交 7254a4b1 authored 作者: abalkin's avatar abalkin

Added TensorVariable.argsort().

上级 4f349e93
......@@ -1781,6 +1781,11 @@ class _tensor_py_operators:
"""See `theano.tensor.argmax`"""
return argmax(self, axis, keepdims=keepdims)
def argsort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort.argsort`"""
from theano.tensor.sort import argsort
return argsort(self, axis, kind, order)
def clip(self, a_min, a_max):
"Clip (limit) the values in an array."
return clip(self, a_min, a_max)
......
......@@ -7016,6 +7016,12 @@ class TestTensorInstanceMethods(unittest.TestCase):
x, _ = self.vals
self.assertTrue(numpy.all(X.argmax().eval({X: x}) == x.argmax()))
def test_argsort(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.argsort().eval({X: x}) == x.argsort()))
self.assertTrue(numpy.all(X.argsort(1).eval({X: x}) == x.argsort(1)))
def test_dot(self):
X, Y = self.vars
x, y = self.vals
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论