提交 b95217e9 authored 作者: nouiz's avatar nouiz

Merge pull request #819 from lamblin/memmap

Consider numpy.memmap a valid type for TensorType
......@@ -300,7 +300,8 @@ TensorType and TensorVariable
.. class:: TensorType(Type)
The Type class used to mark Variables that stand for `numpy.ndarray`
values. Recalling the tutorial, the purple box in
values (`numpy.memmap`, which is a subclass of `numpy.ndarray`, is also allowed).
Recalling the tutorial, the purple box in
:ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class.
.. attribute:: broadcastable
......
......@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a):
"""
:param a: any object
Returns a copy of `a` that shares no internal storage with the original. A deep copy.
This function handles numpy arrays specially to avoid some bug I had one time... (possibly
about copying 0-d arrays?)
Returns a copy of `a` that shares no internal storage with the original.
A deep copy.
This function handles numpy arrays specially, because copy.deepcopy()
called on a 0-d array will return a numpy scalar, not an array.
"""
# this exists because numpy copies are broken
if type(a) is numpy.ndarray:
rval = numpy.array(a, copy=True, dtype=a.dtype)
# this exists because copy.deepcopy on numpy arrays is broken
if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy()
else:
rval = copy.deepcopy(a)
assert type(rval) == type(a)
assert type(rval) == type(a), (type(rval), type(a))
if isinstance(rval, numpy.ndarray):
assert rval.dtype == a.dtype
return rval
......@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker):
if r.owner is None:
assert r in fgraph.inputs
#HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE
if type(dr_vals[r][0]) is numpy.ndarray \
if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \
and dr_vals[r][0].dtype == storage_map[r][0].dtype \
and dr_vals[r][0].shape == storage_map[r][0].shape:
if len(dr_vals[r][0].shape):
......
......@@ -611,6 +611,13 @@ class TensorType(Type):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif((type(data) is numpy.memmap)
and (data.dtype == self.numpy_dtype)):
# numpy.memmap is a "safe" subclass of ndarray,
# so we can use it wherever we expect a base ndarray.
# however, casting it would defeat the purpose of not
# loading the whole data into memory
pass
elif strict:
# If any of the two conditions above was not met,
# we raise a meaningful TypeError.
......
......@@ -218,7 +218,7 @@ class DimShuffle(Op):
storage, = out
# drop
res = input
if type(res) != numpy.ndarray:
if type(res) != numpy.ndarray and type(res) != numpy.memmap:
raise TypeError(res)
shape = list(res.shape)
for drop in reversed(self.drop):
......
import itertools
import logging
import operator
import os
import StringIO
import sys
from tempfile import mkstemp
import unittest
import warnings
from copy import copy, deepcopy
......@@ -176,7 +178,7 @@ def safe_make_node(op, *inputs):
def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
bad_runtime=None, grad=None, mode=None, grad_rtol=None,
eps=1e-10, skip=False):
eps=1e-10, skip=False, test_memmap=True):
if checks is None:
checks = {}
if good is None:
......@@ -193,6 +195,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
_op, _expected, _checks, _good = op, expected, checks, good
_bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad
_mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip
_test_memmap = test_memmap
class Checker(unittest.TestCase):
......@@ -205,6 +208,48 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
grad = _grad
mode = _mode
skip = skip_
test_memmap = _test_memmap
def setUp(self):
    """Add one memmap-backed variant of a "good" test case.

    If ``self.test_memmap`` is true, the inputs of one entry of
    ``self.good`` are copied, and every non-empty plain
    ``numpy.ndarray`` among them is replaced by a ``numpy.memmap``
    holding the same data in a temporary file.  The result is stored
    in ``self.good`` under ``<key>_memmap``.

    ``self.tmp_files`` always ends up as a list of
    ``(file descriptor, path)`` pairs, one per temporary file
    created, so that they can be closed and removed after the test.
    """
    self.tmp_files = []
    if not self.test_memmap:
        return

    # Copy the dict before modifying it, so the original mapping is
    # left untouched.
    self.good = self.good.copy()

    # Iterate over a snapshot, since the dict is extended below.
    for k, v in list(self.good.items()):
        new_k = '_'.join((k, 'memmap'))
        if new_k in self.good:
            # A corresponding key was already provided
            break

        new_v = []
        mapped_any = False
        for inp in v:
            if type(inp) is numpy.ndarray and inp.size > 0:
                f, fname = mkstemp()
                # Register the file right away so it is tracked even
                # if creating the memmap fails.
                self.tmp_files.append((f, fname))
                new_inp = numpy.memmap(fname, dtype=inp.dtype,
                                       mode='w+', shape=inp.shape)
                new_inp[...] = inp[...]
                new_v.append(new_inp)
                mapped_any = True
            else:
                new_v.append(inp)

        if not mapped_any:
            # None of these inputs can be memory-mapped: adding them
            # again under a new key would only duplicate an existing
            # test without exercising memmap.  Try the next entry.
            continue

        self.good[new_k] = new_v
        # We only need one value, no need to copy all of them
        break
def tearDown(self):
    """Close and delete every temporary file listed in ``self.tmp_files``.

    ``self.tmp_files`` holds ``(file descriptor, path)`` pairs.  Every
    close and removal is attempted even if an earlier one fails, so a
    single error does not leave the remaining files behind; the first
    ``OSError`` encountered is re-raised once all files have been
    processed.  The list is emptied so that a second call is a no-op.
    """
    pending, self.tmp_files = self.tmp_files, []
    first_error = None
    for f, fname in pending:
        for action, target in ((os.close, f), (os.remove, fname)):
            try:
                action(target)
            except OSError as exc:
                # Remember the failure but keep cleaning up.
                if first_error is None:
                    first_error = exc
    if first_error is not None:
        raise first_error
def test_good(self):
if skip:
......@@ -301,7 +346,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
if skip:
raise SkipTest(skip)
for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs]
inputrs = [shared(input) for input in inputs]
try:
node = safe_make_node(self.op, *inputrs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论