提交 4786b8ed authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Create files in a tmp dir, clean up afterwards

上级 fc356844
import numpy import numpy
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import os
from tempfile import mkdtemp
import shutil
import unittest
import theano import theano
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
...@@ -11,7 +15,21 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams ...@@ -11,7 +15,21 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load
def test_dump_load(): class T_dump_load(unittest.TestCase):
def setUp(self):
# Work in a temporary directory to avoid cluttering the repository
self.origdir = os.getcwd()
self.tmpdir = None
self.tmpdir = mkdtemp()
os.chdir(self.tmpdir)
def tearDown(self):
# Get back to the original dir, and delete the temporary one
os.chdir(self.origdir)
if self.tmpdir is not None:
shutil.rmtree(self.tmpdir)
def test_dump_load(self):
if not cuda_ndarray.cuda_enabled: if not cuda_ndarray.cuda_enabled:
raise SkipTest('Optional package cuda disabled') raise SkipTest('Optional package cuda disabled')
...@@ -27,8 +45,7 @@ def test_dump_load(): ...@@ -27,8 +45,7 @@ def test_dump_load():
assert x.name == 'x' assert x.name == 'x'
assert_allclose(x.get_value(), [[1]]) assert_allclose(x.get_value(), [[1]])
def test_dump_load_mrg(self):
def test_dump_load_mrg():
rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled) rng = MRG_RandomStreams(use_cuda=cuda_ndarray.cuda_enabled)
with open('test', 'wb') as f: with open('test', 'wb') as f:
...@@ -39,8 +56,7 @@ def test_dump_load_mrg(): ...@@ -39,8 +56,7 @@ def test_dump_load_mrg():
assert type(rng) == MRG_RandomStreams assert type(rng) == MRG_RandomStreams
def test_dump_zip_names(self):
def test_dump_zip_names():
foo_1 = theano.shared(0, name='foo') foo_1 = theano.shared(0, name='foo')
foo_2 = theano.shared(1, name='foo') foo_2 = theano.shared(1, name='foo')
foo_3 = theano.shared(2, name='foo') foo_3 = theano.shared(2, name='foo')
......
...@@ -256,7 +256,7 @@ class T_Scan(unittest.TestCase): ...@@ -256,7 +256,7 @@ class T_Scan(unittest.TestCase):
finally: finally:
f_in.close() f_in.close()
finally: finally:
# Get back to the orinal dir, and delete temporary one. # Get back to the original dir, and delete the temporary one.
os.chdir(origdir) os.chdir(origdir)
if tmpdir is not None: if tmpdir is not None:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
......
...@@ -878,7 +878,7 @@ class T_loading_and_saving(unittest.TestCase): ...@@ -878,7 +878,7 @@ class T_loading_and_saving(unittest.TestCase):
loaded_objects.append(pickle.load(f)) loaded_objects.append(pickle.load(f))
f.close() f.close()
finally: finally:
# Get back to the orinal dir, and temporary one. # Get back to the original dir, and delete the temporary one.
os.chdir(origdir) os.chdir(origdir)
if tmpdir is not None: if tmpdir is not None:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论