提交 bbca839e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3161 from lamblin/misc_fixes

Misc fixes for tests
"""
This test is for testing the NanGuardMode.
"""
from theano.compile.nanguardmode import NanGuardMode
import logging
from nose.tools import assert_raises
import numpy
from theano.compile.nanguardmode import NanGuardMode
import theano
import theano.tensor as T
......@@ -29,21 +33,14 @@ def test_NanGuardMode():
biga = numpy.tile(
numpy.asarray(1e20).astype(theano.config.floatX), (3, 5))
work = [False, False, False]
fun(a) # normal values
try:
fun(infa) # INFs
except AssertionError:
work[0] = True
try:
fun(nana) # NANs
except AssertionError:
work[1] = True
try:
fun(biga) # big values
except AssertionError:
work[2] = True
if not (work[0] and work[1] and work[2]):
raise AssertionError("NanGuardMode not working.")
# Temporarily silence logger
_logger = logging.getLogger("theano.compile.nanguardmode")
try:
_logger.propagate = False
assert_raises(AssertionError, fun, infa) # INFs
assert_raises(AssertionError, fun, nana) # NANs
assert_raises(AssertionError, fun, biga) # big values
finally:
_logger.propagate = True
import os
import shutil
import unittest
from tempfile import mkdtemp
import numpy
from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest
......@@ -11,45 +16,56 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load
def test_dump_load():
if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled')
class T_dump_load(unittest.TestCase):
def setUp(self):
# Work in a temporary directory to avoid cluttering the repository
self.origdir = os.getcwd()
self.tmpdir = mkdtemp()
os.chdir(self.tmpdir)
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False)
def tearDown(self):
# Get back to the original dir, and delete the temporary one
os.chdir(self.origdir)
if self.tmpdir is not None:
shutil.rmtree(self.tmpdir)
with open('test', 'wb') as f:
dump(x, f)
def test_dump_load(self):
if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled')
with open('test', 'rb') as f:
x = load(f)
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False)
assert x.name == 'x'
assert_allclose(x.get_value(), [[1]])
with open('test', 'wb') as f:
dump(x, f)
with open('test', 'rb') as f:
x = load(f)
def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled)
assert x.name == 'x'
assert_allclose(x.get_value(), [[1]])
with open('test', 'wb') as f:
dump(rng, f)
def test_dump_load_mrg(self):
rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled)
with open('test', 'rb') as f:
rng = load(f)
with open('test', 'wb') as f:
dump(rng, f)
assert type(rng) == MRG_RandomStreams
with open('test', 'rb') as f:
rng = load(f)
assert type(rng) == MRG_RandomStreams
def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo')
foo_3 = theano.shared(2, name='foo')
with open('model.zip', 'wb') as f:
dump((foo_1, foo_2, foo_3, numpy.array(3)), f)
keys = list(numpy.load('model.zip').keys())
assert keys == ['foo', 'foo_2', 'foo_3', 'array_0', 'pkl']
foo_3 = numpy.load('model.zip')['foo_3']
assert foo_3 == numpy.array(2)
with open('model.zip', 'rb') as f:
foo_1, foo_2, foo_3, array = load(f)
assert array == numpy.array(3)
def test_dump_zip_names(self):
foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo')
foo_3 = theano.shared(2, name='foo')
with open('model.zip', 'wb') as f:
dump((foo_1, foo_2, foo_3, numpy.array(3)), f)
keys = list(numpy.load('model.zip').keys())
assert keys == ['foo', 'foo_2', 'foo_3', 'array_0', 'pkl']
foo_3 = numpy.load('model.zip')['foo_3']
assert foo_3 == numpy.array(2)
with open('model.zip', 'rb') as f:
foo_1, foo_2, foo_3, array = load(f)
assert array == numpy.array(3)
......@@ -256,7 +256,7 @@ class T_Scan(unittest.TestCase):
finally:
f_in.close()
finally:
# Get back to the orinal dir, and delete temporary one.
# Get back to the original dir, and delete the temporary one.
os.chdir(origdir)
if tmpdir is not None:
shutil.rmtree(tmpdir)
......
......@@ -878,7 +878,7 @@ class T_loading_and_saving(unittest.TestCase):
loaded_objects.append(pickle.load(f))
f.close()
finally:
# Get back to the orinal dir, and temporary one.
# Get back to the original dir, and delete the temporary one.
os.chdir(origdir)
if tmpdir is not None:
shutil.rmtree(tmpdir)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论