提交 9555a94f authored 作者: lamblin's avatar lamblin

Merge pull request #1139 from jlowin/sort_argsort

Add sort method and tensor.argsort function
...@@ -54,6 +54,6 @@ import nnet # used for softmax, sigmoid, etc. ...@@ -54,6 +54,6 @@ import nnet # used for softmax, sigmoid, etc.
from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \ from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
jacobian, hessian jacobian, hessian
from theano.tensor.sort import sort from theano.tensor.sort import sort, argsort
from extra_ops import (DiffOp, bincount, squeeze, from extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal) repeat, bartlett, fill_diagonal)
...@@ -1715,7 +1715,7 @@ class _tensor_py_operators: ...@@ -1715,7 +1715,7 @@ class _tensor_py_operators:
def take(self, indices, axis=None, mode='raise'): def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode) return take(self, indices, axis, mode)
# COPYING # COPYING
def copy(self): def copy(self):
return tensor_copy(self) return tensor_copy(self)
...@@ -1751,7 +1751,7 @@ class _tensor_py_operators: ...@@ -1751,7 +1751,7 @@ class _tensor_py_operators:
return dot(left, right) return dot(left, right)
dot = __dot__ dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False): def sum(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.sum`""" """See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims) return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
...@@ -1796,11 +1796,16 @@ class _tensor_py_operators: ...@@ -1796,11 +1796,16 @@ class _tensor_py_operators:
"""See `theano.tensor.argmax`""" """See `theano.tensor.argmax`"""
return argmax(self, axis, keepdims=keepdims) return argmax(self, axis, keepdims=keepdims)
def sort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort`"""
from theano.tensor.sort import sort
return sort(self, axis, kind, order)
def argsort(self, axis=-1, kind='quicksort', order=None): def argsort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort.argsort`""" """See `theano.tensor.argsort`"""
from theano.tensor.sort import argsort from theano.tensor.sort import argsort
return argsort(self, axis, kind, order) return argsort(self, axis, kind, order)
def clip(self, a_min, a_max): def clip(self, a_min, a_max):
"Clip (limit) the values in an array." "Clip (limit) the values in an array."
return clip(self, a_min, a_max) return clip(self, a_min, a_max)
...@@ -1810,7 +1815,7 @@ class _tensor_py_operators: ...@@ -1810,7 +1815,7 @@ class _tensor_py_operators:
return conj(self) return conj(self)
conjugate = conj conjugate = conj
def repeat(self, repeats, axis=None): def repeat(self, repeats, axis=None):
"""See `theano.tensor.repeat`""" """See `theano.tensor.repeat`"""
from theano.tensor.extra_ops import repeat from theano.tensor.extra_ops import repeat
...@@ -6851,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -6851,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'):
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]]) shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim) return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
...@@ -7285,12 +7290,12 @@ class Diagonal(Op): ...@@ -7285,12 +7290,12 @@ class Diagonal(Op):
:return: A vector representing the diagonal elements. :return: A vector representing the diagonal elements.
""" """
def __init__(self, offset=0, axis1=0, axis2=1): def __init__(self, offset=0, axis1=0, axis2=1):
self.offset = offset self.offset = offset
self.axis1 = axis1 self.axis1 = axis1
self.axis2 = axis2 self.axis2 = axis2
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.offset == other.offset and self.offset == other.offset and
...@@ -7324,7 +7329,7 @@ class Diagonal(Op): ...@@ -7324,7 +7329,7 @@ class Diagonal(Op):
if offset > 0: if offset > 0:
diag_size = clip(dim2 - offset, 0, dim1) diag_size = clip(dim2 - offset, 0, dim1)
elif offset < 0: elif offset < 0:
diag_size = clip(dim1 + offset, 0, dim2) diag_size = clip(dim1 + offset, 0, dim2)
else: else:
diag_size = minimum(dim1, dim2) diag_size = minimum(dim1, dim2)
out_shape.append(diag_size) out_shape.append(diag_size)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论