提交 c6e8e146 authored 作者: Frederic Bastien's avatar Frederic Bastien

Split test to help work around travis timeout. They are super fast here.

上级 a1abed83
......@@ -84,14 +84,13 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, None), [data])
def test_grad_negative_axis(self):
# test 2D
def test_grad_negative_axis_2d(self):
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -2), [data])
# test 3D
def test_grad_negative_axis_3d(self):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -99,7 +98,7 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -3), [data])
# test 4D
def test_grad_negative_axis_4d(self):
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -1), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
......@@ -109,14 +108,13 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, -4), [data])
def test_grad_nonnegative_axis(self):
# test 2D
def test_grad_nonnegative_axis_2d(self):
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 1), [data])
# test 3D
def test_grad_nonnegative_axis_3d(self):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
......@@ -124,7 +122,7 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 2), [data])
# test 4D
def test_grad_nonnegative_axis_4d(self):
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 0), [data])
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论