提交 9a1693c3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Correctly keep track of used tmp files

上级 596e3cdb
...@@ -216,22 +216,25 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -216,22 +216,25 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
def setUp(self): def setUp(self):
# Verify that the test's name is correctly set. # Verify that the test's name is correctly set.
assert eval(self.__class__.__name__) is self.__class__ assert eval(self.__class__.__name__) is self.__class__
# We keep a list of temporary files created in add_memmap_values,
# to remove them at the end of the test.
self.tmp_files = []
def add_memmap_values(self, val_dict):
# If test_memmap is True, we create a temporary file # If test_memmap is True, we create a temporary file
# containing a copy of the data passed in the "good" dict, # containing a copy of the data passed in the "val_dict" dict,
# then open it as a memmapped array, and use the result as a # then open it as a memmapped array, and we can use the result as a
# new test value. # new test value.
# We keep a list of temporary files created, to remove them
# at the end of the test.
self.tmp_files = []
if not self.test_memmap: if not self.test_memmap:
return return val_dict
# Copy dict before modifying them # Copy dict before modifying them
self.good = self.good.copy() val_dict = val_dict.copy()
for k, v in self.good.items(): for k, v in val_dict.items():
new_k = '_'.join((k, 'memmap')) new_k = '_'.join((k, 'memmap'))
if new_k in self.good: if new_k in val_dict:
# A corresponding key was already provided # A corresponding key was already provided
break break
...@@ -246,10 +249,11 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -246,10 +249,11 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
new_v.append(new_inp) new_v.append(new_inp)
else: else:
new_v.append(inp) new_v.append(inp)
self.good[new_k] = new_v val_dict[new_k] = new_v
# We only need one value, no need to copy all of them # We only need one value, no need to copy all of them
break break
return val_dict
def tearDown(self): def tearDown(self):
for f, fname in self.tmp_files: for f, fname in self.tmp_files:
...@@ -259,7 +263,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -259,7 +263,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
def test_good(self): def test_good(self):
if skip: if skip:
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.good.items():
good = self.add_memmap_values(self.good)
for testname, inputs in good.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [TensorType( inputrs = [TensorType(
dtype=input.dtype, dtype=input.dtype,
...@@ -286,7 +293,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -286,7 +293,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
and testname in self.expected): and testname in self.expected):
expecteds = self.expected[testname] expecteds = self.expected[testname]
# with numpy version, when we print a number and read it # with numpy version, when we print a number and read it
# back, we don't get exactly the same result #So we accept # back, we don't get exactly the same result, so we accept
# rounding error in that case. # rounding error in that case.
eps = 5e-9 eps = 5e-9
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论