提交 40274190 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Put back the initial value of compute_test_value at the end of tests

上级 6f2690dd
...@@ -8,6 +8,8 @@ from theano import tensor as T ...@@ -8,6 +8,8 @@ from theano import tensor as T
class TestComputeTestValue(unittest.TestCase): class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self): def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True theano.config.compute_test_value = True
x = T.matrix('x') x = T.matrix('x')
...@@ -21,9 +23,13 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -21,9 +23,13 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail # this test should fail
y.tag.test_value = numpy.random.rand(6,5) y.tag.test_value = numpy.random.rand(6,5)
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_compute_flag(self):
def test_compute_flag(self):
orig_compute_test_value = theano.config.compute_test_value
try:
x = T.matrix('x') x = T.matrix('x')
y = T.matrix('y') y = T.matrix('y')
y.tag.test_value = numpy.random.rand(4,5) y.tag.test_value = numpy.random.rand(4,5)
...@@ -37,8 +43,12 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -37,8 +43,12 @@ class TestComputeTestValue(unittest.TestCase):
self.assertRaises(Warning, T.dot, x, y) self.assertRaises(Warning, T.dot, x, y)
theano.config.compute_test_value = 'err' theano.config.compute_test_value = 'err'
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_string_var(self): def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True theano.config.compute_test_value = True
x = T.matrix('x') x = T.matrix('x')
...@@ -56,8 +66,12 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -56,8 +66,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail # this test should fail
z.set_value(numpy.random.rand(7,6)) z.set_value(numpy.random.rand(7,6))
self.assertRaises(ValueError, f, x, y, z) self.assertRaises(ValueError, f, x, y, z)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_shared(self): def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True theano.config.compute_test_value = True
x = T.matrix('x') x = T.matrix('x')
...@@ -70,8 +84,12 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -70,8 +84,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail # this test should fail
y.set_value(numpy.random.rand(5,6)) y.set_value(numpy.random.rand(5,6))
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_ndarray(self): def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True theano.config.compute_test_value = True
x = numpy.random.rand(2,3) x = numpy.random.rand(2,3)
...@@ -83,8 +101,12 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -83,8 +101,12 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail # this test should fail
x = numpy.random.rand(2,4) x = numpy.random.rand(2,4)
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_constant(self): def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = True theano.config.compute_test_value = True
x = T.constant(numpy.random.rand(2,3)) x = T.constant(numpy.random.rand(2,3))
...@@ -96,3 +118,5 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -96,3 +118,5 @@ class TestComputeTestValue(unittest.TestCase):
# this test should fail # this test should fail
x = T.constant(numpy.random.rand(2,4)) x = T.constant(numpy.random.rand(2,4))
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论