提交 7add8bc5 authored 作者: lamblin's avatar lamblin

Merge pull request #475 from lamblin/sort_axis_none

Implement tensor.sort(..., axis=None)
......@@ -39,6 +39,7 @@ New features:
* Alloc, GpuAlloc are not always pre-computed (constant_folding optimization)
at compile time if all their inputs are constant.
(Frederic B., Pascal L., reported by Sander Dieleman)
* New Op tensor.sort(), wrapping numpy.sort (Hani Almousli)
=============
Release Notes
......
......@@ -5790,12 +5790,15 @@ class SortOp(theano.Op):
str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input)
if axis is None:
axis = Constant(gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
else:
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
out_type = input.type()
return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
a = inputs[0]
......@@ -5804,6 +5807,10 @@ class SortOp(theano.Op):
z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
if inputs_shapes[1] is None:
# That probably means axis = None,
# so the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)]
return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now.
......
......@@ -5584,42 +5584,46 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort():
class test_sort(unittest.TestCase):
testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1, 10, 0, 2]
def setUp(self):
self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2)
self.v_val = self.rng.rand(4)
print "Example 1: "
def test1(self):
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------"
assert numpy.allclose(f(self.m_val), numpy.sort(self.m_val))
print "Example 2: "
def test2(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a, axis)
f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
for axis_val in 0, 1:
assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 3: "
def test3(self):
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------"
assert numpy.allclose(f(self.v_val), numpy.sort(self.v_val))
print "Example 4: "
def test4(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
for axis_val in 0, 1:
assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 5: Check __eq__ function "
def test5(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort", [])
......@@ -5630,14 +5634,12 @@ def test_sort():
assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None"
def test_None(self):
a = theano.tensor.dmatrix()
try:
l = sort(a, None)
except ValueError:
pass
else:
assert False
f = theano.function([a], l)
assert numpy.allclose(f(self.m_val),
numpy.sort(self.m_val, None))
if __name__ == '__main__':
if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论