提交 f4c06f34 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor tests, and test that axis=None works

上级 95a63ec7
...@@ -5584,42 +5584,46 @@ def test_transpose(): ...@@ -5584,42 +5584,46 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1])) assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort(): class test_sort(unittest.TestCase):
testMatrix = [[4, 9, 1], [1, 3, 2]] def setUp(self):
testVector = [1, 10, 0, 2] self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2)
self.v_val = self.rng.rand(4)
print "Example 1: " def test1(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
w = sort(a) w = sort(a)
f = theano.function([a], w) f = theano.function([a], w)
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix)) assert numpy.allclose(f(self.m_val), numpy.sort(self.m_val))
print "------------------------------"
print "Example 2: " def test2(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
w = sort(a, axis) w = sort(a, axis)
f = theano.function([a, axis], w) f = theano.function([a, axis], w)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) for axis_val in 0, 1:
print "------------------------------" assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 3: " def test3(self):
a = theano.tensor.dvector() a = theano.tensor.dvector()
w2 = sort(a) w2 = sort(a)
f = theano.function([a], w2) f = theano.function([a], w2)
assert numpy.allclose(f(testVector), numpy.sort(testVector)) assert numpy.allclose(f(self.v_val), numpy.sort(self.v_val))
print "------------------------------"
print "Example 4: " def test4(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
l = sort(a, axis, "mergesort") l = sort(a, axis, "mergesort")
f = theano.function([a, axis], l) f = theano.function([a, axis], l)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) for axis_val in 0, 1:
print "------------------------------" assert numpy.allclose(
f(self.m_val, axis_val),
numpy.sort(self.m_val, axis_val))
print "Example 5: Check __eq__ function " def test5(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
a1 = SortOp("mergesort", []) a1 = SortOp("mergesort", [])
...@@ -5630,14 +5634,12 @@ def test_sort(): ...@@ -5630,14 +5634,12 @@ def test_sort():
assert a1 == SortOp("mergesort", []) assert a1 == SortOp("mergesort", [])
assert a2 == SortOp("quicksort", []) assert a2 == SortOp("quicksort", [])
print "Example 5: axis=None" def test_None(self):
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
try:
l = sort(a, None) l = sort(a, None)
except ValueError: f = theano.function([a], l)
pass assert numpy.allclose(f(self.m_val),
else: numpy.sort(self.m_val, None))
assert False
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论