提交 8ddca387 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix DebugMode to accept numpy.memmap.

Mention memmap in the documentation.
上级 456fa6b4
...@@ -300,7 +300,8 @@ TensorType and TensorVariable ...@@ -300,7 +300,8 @@ TensorType and TensorVariable
.. class:: TensorType(Type) .. class:: TensorType(Type)
The Type class used to mark Variables that stand for `numpy.ndarray` The Type class used to mark Variables that stand for `numpy.ndarray`
values. Recalling to the tutorial, the purple box in values (`numpy.memmap`, which is a subclass of `numpy.ndarray`, is also allowed).
Recalling to the tutorial, the purple box in
:ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class. :ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class.
.. attribute:: broadcastable .. attribute:: broadcastable
......
...@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a): ...@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a):
""" """
:param a: any object :param a: any object
Returns a copy of `a` that shares no internal storage with the original. A deep copy. Returns a copy of `a` that shares no internal storage with the original.
This function handles numpy arrays specially to avoid some bug I had one time... (possibly A deep copy.
about copying 0-d arrays?) This function handles numpy arrays specially, because copy.deepcopy()
called on a 0-d array will return a numpy scalar, not an array.
""" """
# this exists because numpy copies are broken # this exists because copy.deepcopy on numpy arrays is broken
if type(a) is numpy.ndarray: if type(a) in (numpy.ndarray, numpy.memmap):
rval = numpy.array(a, copy=True, dtype=a.dtype) rval = a.copy()
else: else:
rval = copy.deepcopy(a) rval = copy.deepcopy(a)
assert type(rval) == type(a) assert type(rval) == type(a), (type(rval), type(a))
if isinstance(rval, numpy.ndarray): if isinstance(rval, numpy.ndarray):
assert rval.dtype == a.dtype assert rval.dtype == a.dtype
return rval return rval
...@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker):
if r.owner is None: if r.owner is None:
assert r in fgraph.inputs assert r in fgraph.inputs
#HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE #HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE
if type(dr_vals[r][0]) is numpy.ndarray \ if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \
and dr_vals[r][0].dtype == storage_map[r][0].dtype \ and dr_vals[r][0].dtype == storage_map[r][0].dtype \
and dr_vals[r][0].shape == storage_map[r][0].shape: and dr_vals[r][0].shape == storage_map[r][0].shape:
if len(dr_vals[r][0].shape): if len(dr_vals[r][0].shape):
......
...@@ -218,7 +218,7 @@ class DimShuffle(Op): ...@@ -218,7 +218,7 @@ class DimShuffle(Op):
storage, = out storage, = out
# drop # drop
res = input res = input
if type(res) != numpy.ndarray: if type(res) != numpy.ndarray and type(res) != numpy.memmap:
raise TypeError(res) raise TypeError(res)
shape = list(res.shape) shape = list(res.shape)
for drop in reversed(self.drop): for drop in reversed(self.drop):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论