提交 bbca839e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3161 from lamblin/misc_fixes

Misc fixes for tests
""" """
This test is for testing the NanGuardMode. This test is for testing the NanGuardMode.
""" """
from theano.compile.nanguardmode import NanGuardMode import logging
from nose.tools import assert_raises
import numpy import numpy
from theano.compile.nanguardmode import NanGuardMode
import theano import theano
import theano.tensor as T import theano.tensor as T
...@@ -29,21 +33,14 @@ def test_NanGuardMode(): ...@@ -29,21 +33,14 @@ def test_NanGuardMode():
biga = numpy.tile( biga = numpy.tile(
numpy.asarray(1e20).astype(theano.config.floatX), (3, 5)) numpy.asarray(1e20).astype(theano.config.floatX), (3, 5))
work = [False, False, False]
fun(a) # normal values fun(a) # normal values
try:
fun(infa) # INFs
except AssertionError:
work[0] = True
try:
fun(nana) # NANs
except AssertionError:
work[1] = True
try:
fun(biga) # big values
except AssertionError:
work[2] = True
if not (work[0] and work[1] and work[2]): # Temporarily silence logger
raise AssertionError("NanGuardMode not working.") _logger = logging.getLogger("theano.compile.nanguardmode")
try:
_logger.propagate = False
assert_raises(AssertionError, fun, infa) # INFs
assert_raises(AssertionError, fun, nana) # NANs
assert_raises(AssertionError, fun, biga) # big values
finally:
_logger.propagate = True
import os
import shutil
import unittest
from tempfile import mkdtemp
import numpy import numpy
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -11,7 +16,20 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams ...@@ -11,7 +16,20 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load
def test_dump_load(): class T_dump_load(unittest.TestCase):
def setUp(self):
# Work in a temporary directory to avoid cluttering the repository
self.origdir = os.getcwd()
self.tmpdir = mkdtemp()
os.chdir(self.tmpdir)
def tearDown(self):
# Get back to the original dir, and delete the temporary one
os.chdir(self.origdir)
if self.tmpdir is not None:
shutil.rmtree(self.tmpdir)
def test_dump_load(self):
if not cuda_ndarray.cuda_enabled: if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled') raise SkipTest('Optional package cuda disabled')
...@@ -27,8 +45,7 @@ def test_dump_load(): ...@@ -27,8 +45,7 @@ def test_dump_load():
assert x.name == 'x' assert x.name == 'x'
assert_allclose(x.get_value(), [[1]]) assert_allclose(x.get_value(), [[1]])
def test_dump_load_mrg(self):
def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled) rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled)
with open('test', 'wb') as f: with open('test', 'wb') as f:
...@@ -39,8 +56,7 @@ def test_dump_load_mrg(): ...@@ -39,8 +56,7 @@ def test_dump_load_mrg():
assert type(rng) == MRG_RandomStreams assert type(rng) == MRG_RandomStreams
def test_dump_zip_names(self):
def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo') foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo') foo_2 = theano.shared(1, name='foo')
foo_3 = theano.shared(2, name='foo') foo_3 = theano.shared(2, name='foo')
......
...@@ -256,7 +256,7 @@ class T_Scan(unittest.TestCase): ...@@ -256,7 +256,7 @@ class T_Scan(unittest.TestCase):
finally: finally:
f_in.close() f_in.close()
finally: finally:
# Get back to the orinal dir, and delete temporary one. # Get back to the original dir, and delete the temporary one.
os.chdir(origdir) os.chdir(origdir)
if tmpdir is not None: if tmpdir is not None:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
......
...@@ -878,7 +878,7 @@ class T_loading_and_saving(unittest.TestCase): ...@@ -878,7 +878,7 @@ class T_loading_and_saving(unittest.TestCase):
loaded_objects.append(pickle.load(f)) loaded_objects.append(pickle.load(f))
f.close() f.close()
finally: finally:
# Get back to the orinal dir, and temporary one. # Get back to the original dir, and delete the temporary one.
os.chdir(origdir) os.chdir(origdir)
if tmpdir is not None: if tmpdir is not None:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论