提交 00f26d88 authored 作者: Hani's avatar Hani

Added SortOp

上级 0379bf04
......@@ -5759,24 +5759,28 @@ def all(x, axis=None):
class SortOp(theano.Op):
"""
This class is a wrapper for numpy sort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order
return type(self) == type(other) and self.order == other.order and self.kind==other.kind
def __hash__(self):
return hash(type(self)) ^ hash(self.order)
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
def perform(self, node, inputs, output_storage):
......@@ -5790,7 +5794,6 @@ class SortOp(theano.Op):
#**** It need the argsort, so we can't do it now.
#def grad(self, inputs, output_grads):
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
......@@ -5804,4 +5807,21 @@ class SortOp(theano.Op):
def sort(a, axis=-1, kind='quicksort', order=None):
"""
Return a sorted copy of an array.
a : Tensor
Tensor to be sorted
axis : Tensor
Axis along which to sort .None is not still supported.
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
Sorting algorithm. Default is 'quicksort'.
order : list, optional
When a is a structured array, this argument specifies which fields to compare first, second, and so on. This list does not need to include all of the fields.
"""
return SortOp(kind, order)(a, axis)
......@@ -5586,54 +5586,58 @@ def test_transpose():
def test_sort():
testMatrix = [[4,9,1],[1,3,2]]
testVector = [1,10,0,2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a],w)
print testMatrix
print f(testMatrix)
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a,axis)
f = theano.function([a,axis],w)
print testMatrix
print f(testMatrix,1)
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a],w2)
print testVector
print f(testVector)
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort")
f = theano.function([a,axis],l)
print testMatrix
print f(testMatrix,1)
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[])
a2 = SortOp("quicksort",[])
#All the below should give true
print a1 == a2
print a1 == SortOp("mergesort",[])
print a2 == SortOp("quicksort",[])
testMatrix = [[4,9,1],[1,3,2]]
testVector = [1,10,0,2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a],w)
assert f(testMatrix) == numpy.sort(testMatrix)
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a,axis)
f = theano.function([a,axis],w)
print f(testMatrix,1)
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a],w2)
print f(testVector)
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort")
f = theano.function([a,axis],l)
print f(testMatrix,1)
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[])
a2 = SortOp("quicksort",[])
#All the below should give true
assert a1 != a2
assert a1 == SortOp("mergesort",[])
assert a2 == SortOp("quicksort",[])
print "Example 5: axis=None"
a = theano.tensor.dmatrix()
try:
l = sort(a,None)
except ValueError:
pass
else:
assert False
if __name__ == '__main__':
if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论