提交 8ddca387 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix DebugMode to accept numpy.memmap.

Mention memmap in the documentation.
上级 456fa6b4
......@@ -300,7 +300,8 @@ TensorType and TensorVariable
.. class:: TensorType(Type)
The Type class used to mark Variables that stand for `numpy.ndarray`
values. Recalling to the tutorial, the purple box in
values (`numpy.memmap`, which is a subclass of `numpy.ndarray`, is also allowed).
Recalling to the tutorial, the purple box in
:ref:`the tutorial's graph-structure figure <tutorial-graphfigure>` is an instance of this class.
.. attribute:: broadcastable
......
......@@ -870,17 +870,18 @@ def _lessbroken_deepcopy(a):
"""
:param a: any object
Returns a copy of `a` that shares no internal storage with the original. A deep copy.
This function handles numpy arrays specially to avoid some bug I had one time... (possibly
about copying 0-d arrays?)
Returns a copy of `a` that shares no internal storage with the original.
A deep copy.
This function handles numpy arrays specially, because copy.deepcopy()
called on a 0-d array will return a numpy scalar, not an array.
"""
# this exists because numpy copies are broken
if type(a) is numpy.ndarray:
rval = numpy.array(a, copy=True, dtype=a.dtype)
# this exists because copy.deepcopy on numpy arrays is broken
if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy()
else:
rval = copy.deepcopy(a)
assert type(rval) == type(a)
assert type(rval) == type(a), (type(rval), type(a))
if isinstance(rval, numpy.ndarray):
assert rval.dtype == a.dtype
return rval
......@@ -1992,7 +1993,7 @@ class _Linker(gof.link.LocalLinker):
if r.owner is None:
assert r in fgraph.inputs
#HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE
if type(dr_vals[r][0]) is numpy.ndarray \
if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \
and dr_vals[r][0].dtype == storage_map[r][0].dtype \
and dr_vals[r][0].shape == storage_map[r][0].shape:
if len(dr_vals[r][0].shape):
......
......@@ -218,7 +218,7 @@ class DimShuffle(Op):
storage, = out
# drop
res = input
if type(res) != numpy.ndarray:
if type(res) != numpy.ndarray and type(res) != numpy.memmap:
raise TypeError(res)
shape = list(res.shape)
for drop in reversed(self.drop):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论