提交 9555a94f authored 作者: lamblin's avatar lamblin

Merge pull request #1139 from jlowin/sort_argsort

Add sort method and tensor.argsort function
......@@ -54,6 +54,6 @@ import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian
from theano.tensor.sort import sort
from theano.tensor.sort import sort, argsort
from extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal)
......@@ -1715,7 +1715,7 @@ class _tensor_py_operators:
def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode)
# COPYING
def copy(self):
return tensor_copy(self)
......@@ -1751,7 +1751,7 @@ class _tensor_py_operators:
return dot(left, right)
dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
......@@ -1796,11 +1796,16 @@ class _tensor_py_operators:
"""See `theano.tensor.argmax`"""
return argmax(self, axis, keepdims=keepdims)
def sort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort`"""
from theano.tensor.sort import sort
return sort(self, axis, kind, order)
def argsort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort.argsort`"""
"""See `theano.tensor.argsort`"""
from theano.tensor.sort import argsort
return argsort(self, axis, kind, order)
def clip(self, a_min, a_max):
"Clip (limit) the values in an array."
return clip(self, a_min, a_max)
......@@ -1810,7 +1815,7 @@ class _tensor_py_operators:
return conj(self)
conjugate = conj
def repeat(self, repeats, axis=None):
"""See `theano.tensor.repeat`"""
from theano.tensor.extra_ops import repeat
......@@ -6851,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'):
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
#########################
# Linalg : Dot
#########################
......@@ -7285,12 +7290,12 @@ class Diagonal(Op):
:return: A vector representing the diagonal elements.
"""
def __init__(self, offset=0, axis1=0, axis2=1):
self.offset = offset
self.axis1 = axis1
self.axis2 = axis2
def __eq__(self, other):
return (type(self) == type(other) and
self.offset == other.offset and
......@@ -7324,7 +7329,7 @@ class Diagonal(Op):
if offset > 0:
diag_size = clip(dim2 - offset, 0, dim1)
elif offset < 0:
diag_size = clip(dim1 + offset, 0, dim2)
diag_size = clip(dim1 + offset, 0, dim2)
else:
diag_size = minimum(dim1, dim2)
out_shape.append(diag_size)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论