提交 b60f9a5d authored 作者: lamblin's avatar lamblin

Merge pull request #395 from HaniAlmousli/sort

Added tensor.sort() op
......@@ -5768,3 +5768,76 @@ def any(x, axis=None):
def all(x, axis=None):
return elemwise.All(axis)(x)
class SortOp(theano.Op):
"""
This class is a wrapper for numpy sort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
def __eq__(self, other):
return (type(self) == type(other) and self.order == other.order and
self.kind == other.kind)
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind,
str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now.
#def grad(self, inputs, output_grads):
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)
"""
def sort(a, axis=-1, kind='quicksort', order=None):
"""
Return a sorted copy of an array.
a : Tensor
Tensor to be sorted
axis : Tensor
Axis along which to sort. None is not still supported.
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
Sorting algorithm. Default is 'quicksort'.
order : list, optional
When a is a structured array, this argument specifies which
fields to compare first, second, and so on. This list does not
need to include all of the fields.
"""
return SortOp(kind, order)(a, axis)
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast)
tile, patternbroadcast, sort, SortOp, )
from theano.tests import unittest_tools as utt
......@@ -5584,6 +5584,61 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort():
testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1, 10, 0, 2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a, axis)
f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort", [])
a2 = SortOp("quicksort", [])
#All the below should give true
assert a1 != a2
assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None"
a = theano.tensor.dmatrix()
try:
l = sort(a, None)
except ValueError:
pass
else:
assert False
if __name__ == '__main__':
if 0:
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论