提交 f4c06f34 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor tests, and test that axis=None works

上级 95a63ec7
...@@ -5584,60 +5584,62 @@ def test_transpose(): ...@@ -5584,60 +5584,62 @@ def test_transpose():
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1])) assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
def test_sort(): class test_sort(unittest.TestCase):
testMatrix = [[4, 9, 1], [1, 3, 2]] def setUp(self):
testVector = [1, 10, 0, 2] self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2)
print "Example 1: " self.v_val = self.rng.rand(4)
a = theano.tensor.dmatrix()
w = sort(a) def test1(self):
f = theano.function([a], w) a = theano.tensor.dmatrix()
assert numpy.allclose(f(testMatrix), numpy.sort(testMatrix)) w = sort(a)
print "------------------------------" f = theano.function([a], w)
assert numpy.allclose(f(self.m_val), numpy.sort(self.m_val))
print "Example 2: "
a = theano.tensor.dmatrix() def test2(self):
axis = theano.tensor.scalar() a = theano.tensor.dmatrix()
w = sort(a, axis) axis = theano.tensor.scalar()
f = theano.function([a, axis], w) w = sort(a, axis)
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) f = theano.function([a, axis], w)
print "------------------------------" for axis_val in 0, 1:
assert numpy.allclose(
print "Example 3: " f(self.m_val, axis_val),
a = theano.tensor.dvector() numpy.sort(self.m_val, axis_val))
w2 = sort(a)
f = theano.function([a], w2) def test3(self):
assert numpy.allclose(f(testVector), numpy.sort(testVector)) a = theano.tensor.dvector()
print "------------------------------" w2 = sort(a)
f = theano.function([a], w2)
print "Example 4: " assert numpy.allclose(f(self.v_val), numpy.sort(self.v_val))
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() def test4(self):
l = sort(a, axis, "mergesort") a = theano.tensor.dmatrix()
f = theano.function([a, axis], l) axis = theano.tensor.scalar()
assert numpy.allclose(f(testMatrix, 1), numpy.sort(testMatrix, 1)) l = sort(a, axis, "mergesort")
print "------------------------------" f = theano.function([a, axis], l)
for axis_val in 0, 1:
print "Example 5: Check __eq__ function " assert numpy.allclose(
a = theano.tensor.dmatrix() f(self.m_val, axis_val),
axis = theano.tensor.scalar() numpy.sort(self.m_val, axis_val))
a1 = SortOp("mergesort", [])
a2 = SortOp("quicksort", []) def test5(self):
a = theano.tensor.dmatrix()
#All the below should give true axis = theano.tensor.scalar()
assert a1 != a2 a1 = SortOp("mergesort", [])
assert a1 == SortOp("mergesort", []) a2 = SortOp("quicksort", [])
assert a2 == SortOp("quicksort", [])
#All the below should give true
print "Example 5: axis=None" assert a1 != a2
a = theano.tensor.dmatrix() assert a1 == SortOp("mergesort", [])
try: assert a2 == SortOp("quicksort", [])
def test_None(self):
a = theano.tensor.dmatrix()
l = sort(a, None) l = sort(a, None)
except ValueError: f = theano.function([a], l)
pass assert numpy.allclose(f(self.m_val),
else: numpy.sort(self.m_val, None))
assert False
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论