提交 c31d4f6b authored 作者: Frederic's avatar Frederic

pep8

上级 9150cbe6
......@@ -5763,22 +5763,24 @@ class SortOp(theano.Op):
This class is a wrapper for numpy sort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind
return (type(self) == type(other) and self.order == other.order and
self.kind == other.kind)
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
return self.__class__.__name__ + "{%s, %s}" % (self.kind,
str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
......@@ -5787,7 +5789,7 @@ class SortOp(theano.Op):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.sort(a,axis,self.kind,self.order)
z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
......@@ -5821,7 +5823,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
order : list, optional
When a is a structured array, this argument specifies which fields to compare first, second, and so on. This list does not need to include all of the fields.
When a is a structured array, this argument specifies which
fields to compare first, second, and so on. This list does not
need to include all of the fields.
"""
return SortOp(kind, order)(a, axis)
......@@ -5586,54 +5586,54 @@ def test_transpose():
def test_sort():
testMatrix = [[4,9,1],[1,3,2]]
testVector = [1,10,0,2]
testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1, 10, 0, 2]
print "Example 1: "
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a],w)
f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------"
print "Example 2: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a,axis)
f = theano.function([a,axis],w)
w = sort(a, axis)
f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 3: "
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a],w2)
f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------"
print "Example 4: "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort")
f = theano.function([a,axis],l)
l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[])
a2 = SortOp("quicksort",[])
a1 = SortOp("mergesort", [])
a2 = SortOp("quicksort", [])
#All the below should give true
assert a1 != a2
assert a1 == SortOp("mergesort",[])
assert a2 == SortOp("quicksort",[])
assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None"
a = theano.tensor.dmatrix()
try:
l = sort(a,None)
l = sort(a, None)
except ValueError:
pass
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论