提交 f4c06f34 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor tests, and test that axis=None works

上级 95a63ec7
......@@ -5584,42 +5584,46 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort():
class test_sort(unittest.TestCase):
testMatrix = [[4, 9, 1], [1, 3, 2]]
testVector = [1, 10, 0, 2]
def setUp(self):
self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2)
self.v_val = self.rng.rand(4)
print "Example 1: "
def test1(self):
a = theano.tensor.dmatrix()
w = sort(a)
f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix))
print "------------------------------"
assert numpy.allclose(f(self.m_val), numpy.sort(self.m_val))
print "Example 2: "
def test2(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = sort(a, axis)
f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
for axis_val in 0, 1:
assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 3: "
def test3(self):
a = theano.tensor.dvector()
w2 = sort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector))
print "------------------------------"
assert numpy.allclose(f(self.v_val), numpy.sort(self.v_val))
print "Example 4: "
def test4(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1))
print "------------------------------"
for axis_val in 0, 1:
assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 5: Check __eq__ function "
def test5(self):
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = SortOp("mergesort", [])
......@@ -5630,14 +5634,12 @@ def test_sort():
assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None"
def test_None(self):
a = theano.tensor.dmatrix()
try:
l = sort(a, None)
except ValueError:
pass
else:
assert False
f = theano.function([a], l)
assert numpy.allclose(f(self.m_val),
numpy.sort(self.m_val, None))
if __name__ == '__main__':
if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论