提交 074e7ffa authored 作者: HaniAlmousli's avatar HaniAlmousli

Merge pull request #1 from nouiz/sort2

Sort2
......@@ -5763,22 +5763,24 @@ class SortOp(theano.Op):
This class is a wrapper for numpy sort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind
return (type(self) == type(other) and self.order == other.order and
self.kind == other.kind)
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
return self.__class__.__name__ + "{%s, %s}" % (self.kind,
str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
......@@ -5787,7 +5789,7 @@ class SortOp(theano.Op):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.sort(a,axis,self.kind,self.order)
z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
......@@ -5813,7 +5815,7 @@ def sort(a, axis=-1, kind='quicksort', order=None):
Tensor to be sorted
axis : Tensor
Axis along which to sort .None is not still supported.
Axis along which to sort. None is not still supported.
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
......@@ -5821,7 +5823,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
order : list, optional
When a is a structured array, this argument specifies which fields to compare first, second, and so on. This list does not need to include all of the fields.
When a is a structured array, this argument specifies which
fields to compare first, second, and so on. This list does not
need to include all of the fields.
"""
return SortOp(kind, order)(a, axis)
......@@ -5586,54 +5586,54 @@ def test_transpose():
def test_sort():
testMatrix = [[4,9,1],[1,3,2]]
testVector = [1,10,0,2]
testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1, 10, 0, 2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a],w)
assert f(testMatrix) == numpy.sort(testMatrix)
f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a,axis)
f = theano.function([a,axis],w)
print f(testMatrix,1)
w = sort(a, axis)
f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a],w2)
print f(testVector)
f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort")
f = theano.function([a,axis],l)
print f(testMatrix,1)
l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[])
a2 = SortOp("quicksort",[])
a1 = SortOp("mergesort", [])
a2 = SortOp("quicksort", [])
#All the below should give true
assert a1 != a2
assert a1 == SortOp("mergesort",[])
assert a2 == SortOp("quicksort",[])
assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None"
a = theano.tensor.dmatrix()
try:
l = sort(a,None)
l = sort(a, None)
except ValueError:
pass
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论