提交 b95217e9 authored 作者: nouiz's avatar nouiz

Merge pull request #819 from lamblin/memmap

Consider numpy.memmap a valid type for TensorType
...@@ -300,7 +300,8 @@ TensorType and TensorVariable ...@@ -300,7 +300,8 @@ TensorType and TensorVariable
.. class:: TensorType(Type) .. class:: TensorType(Type)
The Type class used to mark Variables that stand for `numpy.ndarray` The Type class used to mark Variables that stand for `numpy.ndarray`
values. Recalling to the tutorial, the purple box in values (`numpy.memmap`, which is a subclass of `numpy.ndarray`, is also allowed).
Recall from the tutorial that the purple box in
:ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class. :ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class.
.. attribute:: broadcastable .. attribute:: broadcastable
......
...@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a): ...@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a):
""" """
:param a: any object :param a: any object
Returns a copy of `a` that shares no internal storage with the original. A deep copy. Returns a copy of `a` that shares no internal storage with the original.
This function handles numpy arrays specially to avoid some bug I had one time... (possibly A deep copy.
about copying 0-d arrays?) This function handles numpy arrays specially, because copy.deepcopy()
called on a 0-d array will return a numpy scalar, not an array.
""" """
# this exists because numpy copies are broken # this exists because copy.deepcopy on numpy arrays is broken
if type(a) is numpy.ndarray: if type(a) in (numpy.ndarray, numpy.memmap):
rval = numpy.array(a, copy=True, dtype=a.dtype) rval = a.copy()
else: else:
rval = copy.deepcopy(a) rval = copy.deepcopy(a)
assert type(rval) == type(a) assert type(rval) == type(a), (type(rval), type(a))
if isinstance(rval, numpy.ndarray): if isinstance(rval, numpy.ndarray):
assert rval.dtype == a.dtype assert rval.dtype == a.dtype
return rval return rval
...@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker):
if r.owner is None: if r.owner is None:
assert r in fgraph.inputs assert r in fgraph.inputs
#HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE #HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE
if type(dr_vals[r][0]) is numpy.ndarray \ if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \
and dr_vals[r][0].dtype == storage_map[r][0].dtype \ and dr_vals[r][0].dtype == storage_map[r][0].dtype \
and dr_vals[r][0].shape == storage_map[r][0].shape: and dr_vals[r][0].shape == storage_map[r][0].shape:
if len(dr_vals[r][0].shape): if len(dr_vals[r][0].shape):
......
...@@ -611,6 +611,13 @@ class TensorType(Type): ...@@ -611,6 +611,13 @@ class TensorType(Type):
if data.dtype.num != self.numpy_dtype.num: if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype) data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check # -- now fall through to ndim check
elif((type(data) is numpy.memmap)
and (data.dtype == self.numpy_dtype)):
# numpy.memmap is a "safe" subclass of ndarray,
# so we can use it wherever we expect a base ndarray.
# however, casting it would defeat the purpose of not
# loading the whole data into memory
pass
elif strict: elif strict:
# If any of the two conditions above was not met, # If any of the two conditions above was not met,
# we raise a meaningful TypeError. # we raise a meaningful TypeError.
......
...@@ -218,7 +218,7 @@ class DimShuffle(Op): ...@@ -218,7 +218,7 @@ class DimShuffle(Op):
storage, = out storage, = out
# drop # drop
res = input res = input
if type(res) != numpy.ndarray: if type(res) != numpy.ndarray and type(res) != numpy.memmap:
raise TypeError(res) raise TypeError(res)
shape = list(res.shape) shape = list(res.shape)
for drop in reversed(self.drop): for drop in reversed(self.drop):
......
import itertools import itertools
import logging import logging
import operator import operator
import os
import StringIO import StringIO
import sys import sys
from tempfile import mkstemp
import unittest import unittest
import warnings import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
...@@ -176,7 +178,7 @@ def safe_make_node(op, *inputs): ...@@ -176,7 +178,7 @@ def safe_make_node(op, *inputs):
def makeTester(name, op, expected, checks=None, good=None, bad_build=None, def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
bad_runtime=None, grad=None, mode=None, grad_rtol=None, bad_runtime=None, grad=None, mode=None, grad_rtol=None,
eps=1e-10, skip=False): eps=1e-10, skip=False, test_memmap=True):
if checks is None: if checks is None:
checks = {} checks = {}
if good is None: if good is None:
...@@ -193,6 +195,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -193,6 +195,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
_op, _expected, _checks, _good = op, expected, checks, good _op, _expected, _checks, _good = op, expected, checks, good
_bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad _bad_build, _bad_runtime, _grad = bad_build, bad_runtime, grad
_mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip _mode, _grad_rtol, _eps, skip_ = mode, grad_rtol, eps, skip
_test_memmap = test_memmap
class Checker(unittest.TestCase): class Checker(unittest.TestCase):
...@@ -205,6 +208,48 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -205,6 +208,48 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
grad = _grad grad = _grad
mode = _mode mode = _mode
skip = skip_ skip = skip_
test_memmap = _test_memmap
def setUp(self):
    """Reset the temp-file list and optionally add one memmap test case.

    When ``self.test_memmap`` is true, one entry of ``self.good`` is
    duplicated under the key ``<name>_memmap``.  In the duplicate, every
    non-empty ``numpy.ndarray`` input is replaced by a ``numpy.memmap``
    backed by a fresh temporary file holding the same data.  Other inputs
    are reused unchanged.

    Each temporary file is recorded in ``self.tmp_files`` as a
    ``(file descriptor, path)`` pair so it can be closed and removed at
    the end of the test.
    """
    # We keep a list of temporary files created, to remove them
    # at the end of the test.
    self.tmp_files = []
    if not self.test_memmap:
        return

    # Copy the dict before modifying it: the new entry is stored on this
    # copy, and the dict object we started from is left untouched.
    self.good = self.good.copy()

    # Iterate in sorted key order so that the entry receiving a memmap
    # variant does not depend on the dict's iteration order.
    for k, v in sorted(self.good.items()):
        new_k = '_'.join((k, 'memmap'))
        if new_k in self.good:
            # A corresponding key was already provided
            break

        new_v = []
        memmapped = False
        for inp in v:
            if type(inp) is numpy.ndarray and inp.size > 0:
                f, fname = mkstemp()
                self.tmp_files.append((f, fname))
                new_inp = numpy.memmap(fname, dtype=inp.dtype,
                                       mode='w+', shape=inp.shape)
                new_inp[...] = inp[...]
                new_v.append(new_inp)
                memmapped = True
            else:
                new_v.append(inp)

        if not memmapped:
            # Nothing in this entry could be memmapped: adding it would
            # only duplicate an existing test case, so try the next one.
            continue

        self.good[new_k] = new_v
        # We only need one value, no need to copy all of them
        break
def tearDown(self):
    """Close and delete every temporary file listed in ``self.tmp_files``.

    ``self.tmp_files`` holds ``(file descriptor, path)`` pairs.  Every
    descriptor is closed and every path removed, even if one of these
    steps fails.  The first ``OSError`` encountered is re-raised once all
    entries have been processed.

    :raises OSError: if closing a descriptor or removing a file failed.
    """
    # Detach the list first, so a repeated call does not try to close
    # the same descriptors again.
    tmp_files, self.tmp_files = self.tmp_files, []
    first_error = None
    for f, fname in tmp_files:
        # Attempt both steps independently: a failure to close must not
        # leave the file (or the following entries) behind.
        for cleanup, target in ((os.close, f), (os.remove, fname)):
            try:
                cleanup(target)
            except OSError:
                if first_error is None:
                    first_error = sys.exc_info()[1]
    if first_error is not None:
        raise first_error
def test_good(self): def test_good(self):
if skip: if skip:
...@@ -301,7 +346,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -301,7 +346,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
if skip: if skip:
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.bad_runtime.items(): for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs]
inputrs = [shared(input) for input in inputs] inputrs = [shared(input) for input in inputs]
try: try:
node = safe_make_node(self.op, *inputrs) node = safe_make_node(self.op, *inputrs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论