提交 bf8e97a7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test memmap as input in basic tensor tests

上级 f9ea6e0c
import itertools import itertools
import logging import logging
import operator import operator
import os
import StringIO import StringIO
import sys import sys
from tempfile import mkstemp
import unittest import unittest
import warnings import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
...@@ -176,7 +178,7 @@ def safe_make_node(op, *inputs): ...@@ -176,7 +178,7 @@ def safe_make_node(op, *inputs):
def makeTester(name, op, expected, checks=None, good=None, bad_build=None, def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
bad_runtime=None, grad=None, mode=None, grad_rtol=None, bad_runtime=None, grad=None, mode=None, grad_rtol=None,
eps=1e-10, skip=False): eps=1e-10, skip=False, test_memmap=True):
if checks is None: if checks is None:
checks = {} checks = {}
if good is None: if good is None:
...@@ -193,6 +195,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -193,6 +195,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
_op, _expected, _checks, _good = op, expected, checks, good _op, _expected, _checks, _good = op, expected, checks, good
_bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad _bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad
_mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip _mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip
_test_memmap = test_memmap
class Checker(unittest.TestCase): class Checker(unittest.TestCase):
...@@ -205,6 +208,48 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -205,6 +208,48 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
grad = _grad grad = _grad
mode = _mode mode = _mode
skip = skip_ skip = skip_
test_memmap = _test_memmap
def setUp(self):
# If test_memmap is True, we create a temporary file
# containing a copy of the data passed in the "good" dict,
# then open it as a memmapped array, and use the result as a
# new test value.
# We keep a list of temporary files created, to remove them
# at the end of the test.
self.tmp_files = []
if not self.test_memmap:
return
# Copy dict before modifying them
self.good = self.good.copy()
for k, v in self.good.items():
new_k = '_'.join((k, 'memmap'))
if new_k in self.good:
# A corresponding key was already provided
break
new_v = []
for inp in v:
if type(inp) is numpy.ndarray and inp.size > 0:
f, fname = mkstemp()
self.tmp_files.append((f, fname))
new_inp = numpy.memmap(fname, dtype=inp.dtype,
mode='w+', shape=inp.shape)
new_inp[...] = inp[...]
new_v.append(new_inp)
else:
new_v.append(copy(inp))
self.good[new_k] = new_v
# We only need one value, no need to copy all of them
break
def tearDown(self):
for f, fname in self.tmp_files:
os.close(f)
os.remove(fname)
def test_good(self): def test_good(self):
if skip: if skip:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论