提交 3b14da6b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Silence expected _logger error messages in some tests

上级 af9d4e8b
...@@ -650,8 +650,12 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -650,8 +650,12 @@ class T_max_and_argmax(unittest.TestCase):
def test2_invalid(self): def test2_invalid(self):
n = as_tensor_variable(numpy.random.rand(2,3)) n = as_tensor_variable(numpy.random.rand(2,3))
# Silence expected error messages
old_stderr = sys.stderr old_stderr = sys.stderr
sys.stderr = StringIO.StringIO() sys.stderr = StringIO.StringIO()
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.getEffectiveLevel()
_logger.setLevel(logging.CRITICAL)
try: try:
eval_outputs(max_and_argmax(n,3)) eval_outputs(max_and_argmax(n,3))
assert False assert False
...@@ -659,6 +663,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -659,6 +663,7 @@ class T_max_and_argmax(unittest.TestCase):
pass pass
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
_logger.setLevel(oldlevel)
def test2_invalid_neg(self): def test2_invalid_neg(self):
n = as_tensor_variable(numpy.random.rand(2,3)) n = as_tensor_variable(numpy.random.rand(2,3))
old_stderr = sys.stderr old_stderr = sys.stderr
...@@ -719,8 +724,12 @@ class T_subtensor(unittest.TestCase): ...@@ -719,8 +724,12 @@ class T_subtensor(unittest.TestCase):
n = as_tensor_variable(numpy.ones(3)) n = as_tensor_variable(numpy.ones(3))
t = n[7] t = n[7]
self.failUnless(isinstance(t.owner.op, Subtensor)) self.failUnless(isinstance(t.owner.op, Subtensor))
# Silence expected error messages
old_stderr = sys.stderr old_stderr = sys.stderr
sys.stderr = StringIO.StringIO() sys.stderr = StringIO.StringIO()
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.getEffectiveLevel()
_logger.setLevel(logging.CRITICAL)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
assert 0 assert 0
...@@ -729,6 +738,7 @@ class T_subtensor(unittest.TestCase): ...@@ -729,6 +738,7 @@ class T_subtensor(unittest.TestCase):
raise raise
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
_logger.setLevel(oldlevel)
def test1_err_subslice(self): def test1_err_subslice(self):
n = as_tensor_variable(numpy.ones(3)) n = as_tensor_variable(numpy.ones(3))
try: try:
...@@ -796,6 +806,9 @@ class T_subtensor(unittest.TestCase): ...@@ -796,6 +806,9 @@ class T_subtensor(unittest.TestCase):
self.failUnless(isinstance(t.owner.op, Subtensor)) self.failUnless(isinstance(t.owner.op, Subtensor))
old_stderr = sys.stderr old_stderr = sys.stderr
sys.stderr = StringIO.StringIO() sys.stderr = StringIO.StringIO()
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.getEffectiveLevel()
_logger.setLevel(logging.CRITICAL)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
assert 0 assert 0
...@@ -803,6 +816,7 @@ class T_subtensor(unittest.TestCase): ...@@ -803,6 +816,7 @@ class T_subtensor(unittest.TestCase):
pass pass
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
_logger.setLevel(oldlevel)
def test2_err_bounds1(self): def test2_err_bounds1(self):
n = as_tensor_variable(numpy.ones((2,3))*5) n = as_tensor_variable(numpy.ones((2,3))*5)
t = n[4:5,2] t = n[4:5,2]
...@@ -1464,10 +1478,13 @@ class t_dot(unittest.TestCase): ...@@ -1464,10 +1478,13 @@ class t_dot(unittest.TestCase):
def not_aligned(self, x, y): def not_aligned(self, x, y):
z = dot(x,y) z = dot(x,y)
# constant folding will complain to stderr that things are not aligned # constant folding will complain to _logger that things are not aligned
# this is normal, testers are not interested in seeing that output. # this is normal, testers are not interested in seeing that output.
old_stderr = sys.stderr old_stderr = sys.stderr
sys.stderr = StringIO.StringIO() sys.stderr = StringIO.StringIO()
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.getEffectiveLevel()
_logger.setLevel(logging.CRITICAL)
try: try:
tz = eval_outputs([z]) tz = eval_outputs([z])
assert False # should have raised exception assert False # should have raised exception
...@@ -1477,6 +1494,7 @@ class t_dot(unittest.TestCase): ...@@ -1477,6 +1494,7 @@ class t_dot(unittest.TestCase):
e[0].split()[0:2] == ['Shape', 'mismatch:'], e) # reported by blas return self.fail() e[0].split()[0:2] == ['Shape', 'mismatch:'], e) # reported by blas return self.fail()
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
_logger.setLevel(oldlevel)
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6)) def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4)) def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论