提交 7add8bc5 作者: lamblin

Merge pull request #475 from lamblin/sort_axis_none

Implement tensor.sort(..., axis=None)
...@@ -39,6 +39,7 @@ New features: ...@@ -39,6 +39,7 @@ New features:
* Alloc, GpuAlloc are not always pre-computed (constant_folding optimization) * Alloc, GpuAlloc are not always pre-computed (constant_folding optimization)
at compile time if all their inputs are constant. at compile time if all their inputs are constant.
(Frederic B., Pascal L., reported by Sander Dieleman) (Frederic B., Pascal L., reported by Sander Dieleman)
* New Op tensor.sort(), wrapping numpy.sort (Hani Almousli)
============= =============
Release Notes Release Notes
......
...@@ -5790,12 +5790,15 @@ class SortOp(theano.Op): ...@@ -5790,12 +5790,15 @@ class SortOp(theano.Op):
str(self.order)) str(self.order))
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if axis is None:
axis = Constant(gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
else:
axis = theano.tensor.as_tensor_variable(axis) axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()]) out_type = input.type()
return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
a = inputs[0] a = inputs[0]
...@@ -5804,6 +5807,10 @@ class SortOp(theano.Op): ...@@ -5804,6 +5807,10 @@ class SortOp(theano.Op):
z[0] = numpy.sort(a, axis, self.kind, self.order) z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if inputs_shapes[1] is None:
# That probably means axis = None,
# so the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)]
return [inputs_shapes[0]] return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now. #**** It need the argsort, so we can't do it now.
......
...@@ -5584,42 +5584,46 @@ def test_transpose(): ...@@ -5584,42 +5584,46 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1])) assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort(): class test_sort(unittest.TestCase):
testMatrix = [[4, 9, 1], [1, 3, 2]] def setUp(self):
testVector = [1, 10, 0, 2] self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2)
self.v_val = self.rng.rand(4)
print "Example 1: " def test1(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
w = sort(a) w = sort(a)
f = theano.function([a], w) f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix)) assert numpy.allclose(f(self.m_val), numpy.sort(self.m_val))
print "------------------------------"
print "Example 2: " def test2(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
w = sort(a, axis) w = sort(a, axis)
f = theano.function([a, axis], w) f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) for axis_val in 0, 1:
print "------------------------------" assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 3: " def test3(self):
a = theano.tensor.dvector() a = theano.tensor.dvector()
w2 = sort(a) w2 = sort(a)
f = theano.function([a], w2) f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector)) assert numpy.allclose(f(self.v_val), numpy.sort(self.v_val))
print "------------------------------"
print "Example 4: " def test4(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
l = sort(a, axis, "mergesort") l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l) f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) for axis_val in 0, 1:
print "------------------------------" assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 5: Check __eq__ function " def test5(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
a1 = SortOp("mergesort", []) a1 = SortOp("mergesort", [])
...@@ -5630,14 +5634,12 @@ def test_sort(): ...@@ -5630,14 +5634,12 @@ def test_sort():
assert a1 == SortOp("mergesort", []) assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", []) assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None" def test_None(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
try:
l = sort(a, None) l = sort(a, None)
except ValueError: f = theano.function([a], l)
pass assert numpy.allclose(f(self.m_val),
else: numpy.sort(self.m_val, None))
assert False
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论