提交 0379bf04 authored 作者: Hani's avatar Hani

Added tensor.sort() op

上级 8b1c4916
......@@ -5756,3 +5756,52 @@ def any(x, axis=None):
def all(x, axis=None):
return elemwise.All(axis)(x)
class SortOp(theano.Op):
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order
def __hash__(self):
return hash(type(self)) ^ hash(self.order)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.sort(a,axis,self.kind,self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now.
#def grad(self, inputs, output_grads):
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)
"""
def sort(a, axis=-1, kind='quicksort', order=None):
return SortOp(kind, order)(a, axis)
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast)
tile, patternbroadcast, sort, SortOp, )
from theano.tests import unittest_tools as utt
......@@ -5584,6 +5584,57 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort():
testMatrix = [[4,9,1],[1,3,2]]
testVector = [1,10,0,2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a],w)
print testMatrix
print f(testMatrix)
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a,axis)
f = theano.function([a,axis],w)
print testMatrix
print f(testMatrix,1)
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a],w2)
print testVector
print f(testVector)
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort")
f = theano.function([a,axis],l)
print testMatrix
print f(testMatrix,1)
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[])
a2 = SortOp("quicksort",[])
#All the below should give true
print a1 == a2
print a1 == SortOp("mergesort",[])
print a2 == SortOp("quicksort",[])
if __name__ == '__main__':
if 0:
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论