提交 40274190 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Put back the initial value of compute_test_value at the end of tests

上级 6f2690dd
......@@ -8,6 +8,8 @@ from theano import tensor as T
class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True
x = T.matrix('x')
......@@ -21,9 +23,13 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail
y.tag.test_value = numpy.random.rand(6,5)
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_compute_flag(self):
def test_compute_flag(self):
orig_compute_test_value = theano.config.compute_test_value
try:
x = T.matrix('x')
y = T.matrix('y')
y.tag.test_value = numpy.random.rand(4,5)
......@@ -37,8 +43,12 @@ class TestComputeTestValue(unittest.TestCase):
self.assertRaises(Warning, T.dot, x, y)
theano.config.compute_test_value = 'err'
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True
x = T.matrix('x')
......@@ -56,8 +66,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail
z.set_value(numpy.random.rand(7,6))
self.assertRaises(ValueError, f, x, y, z)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True
x = T.matrix('x')
......@@ -70,8 +84,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail
y.set_value(numpy.random.rand(5,6))
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True
x = numpy.random.rand(2,3)
......@@ -83,8 +101,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail
x = numpy.random.rand(2,4)
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True
x = T.constant(numpy.random.rand(2,3))
......@@ -96,3 +118,5 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail
x = T.constant(numpy.random.rand(2,4))
self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论