提交 c31d4f6b authored 作者: Frederic's avatar Frederic

pep8

上级 9150cbe6
...@@ -5767,18 +5767,20 @@ class SortOp(theano.Op): ...@@ -5767,18 +5767,20 @@ class SortOp(theano.Op):
self.order = order self.order = order
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind return (type(self) == type(other) and self.order == other.order and
self.kind == other.kind)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind) return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self): def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order)) return self.__class__.__name__ + "{%s, %s}" % (self.kind,
str(self.order))
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
if axis is None: if axis is None:
raise ValueError("Current Implementation does not sipport axis=None") raise ValueError("Current Implementation does not support"
return " axis=None")
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis) axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()]) return theano.Apply(self, [input, axis], [input.type()])
...@@ -5787,7 +5789,7 @@ class SortOp(theano.Op): ...@@ -5787,7 +5789,7 @@ class SortOp(theano.Op):
a = inputs[0] a = inputs[0]
axis = inputs[1] axis = inputs[1]
z = output_storage[0] z = output_storage[0]
z[0] = numpy.sort(a,axis,self.kind,self.order) z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]] return [inputs_shapes[0]]
...@@ -5821,7 +5823,9 @@ def sort(a, axis=-1, kind='quicksort', order=None): ...@@ -5821,7 +5823,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
order : list, optional order : list, optional
When a is a structured array, this argument specifies which fields to compare first, second, and so on. This list does not need to include all of the fields. When a is a structured array, this argument specifies which
fields to compare first, second, and so on. This list does not
need to include all of the fields.
""" """
return SortOp(kind, order)(a, axis) return SortOp(kind, order)(a, axis)
...@@ -5586,54 +5586,54 @@ def test_transpose(): ...@@ -5586,54 +5586,54 @@ def test_transpose():
def test_sort(): def test_sort():
testMatrix = [[4,9,1],[1,3,2]] testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1,10,0,2] testVector = [1, 10, 0, 2]
print "Example 1: " print "Example 1: "
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
w = sort(a) w = sort(a)
f = theano.function([a],w) f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix)) assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------" print "------------------------------"
print "Example 2: " print "Example 2: "
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
w = sort(a,axis) w = sort(a, axis)
f = theano.function([a,axis],w) f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------" print "------------------------------"
print "Example 3: " print "Example 3: "
a = theano.tensor.dvector() a = theano.tensor.dvector()
w2 = sort(a) w2 = sort(a)
f = theano.function([a],w2) f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector)) assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------" print "------------------------------"
print "Example 4: " print "Example 4: "
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
l = sort(a,axis,"mergesort") l = sort(a, axis, "mergesort")
f = theano.function([a,axis],l) f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------" print "------------------------------"
print "Example 5: Check __eq__ function " print "Example 5: Check __eq__ function "
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
a1 = SortOp("mergesort",[]) a1 = SortOp("mergesort", [])
a2 = SortOp("quicksort",[]) a2 = SortOp("quicksort", [])
#All the below should give true #All the below should give true
assert a1 != a2 assert a1 != a2
assert a1 == SortOp("mergesort",[]) assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort",[]) assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None" print "Example 5: axis=None"
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
try: try:
l = sort(a,None) l = sort(a, None)
except ValueError: except ValueError:
pass pass
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论