提交 0e1102f1 authored 作者: Frederic's avatar Frederic

Better fix that would work for more cases if this happen again.

Also fix the test.
上级 8f7c7e9d
......@@ -4,8 +4,6 @@ import StringIO
import sys
import traceback
import numpy
import theano
from theano.gof import utils
from theano.gof import graph
......@@ -321,21 +319,18 @@ class Container(object):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
# this exists because copy.deepcopy on numpy arrays is broken
a = self.storage[0]
if type(a) in (numpy.ndarray, numpy.memmap):
a = a.copy()
else:
a = copy.deepcopy(a)
r = type(self)(
deepcopy(self.type, memo),
[a],
deepcopy(self.storage),
deepcopy(self.readonly),
deepcopy(self.strict),
deepcopy(self.allow_downcast),
deepcopy(self.name, memo),
)
# To force the call to filter. This is a work around NumPy
# deepcopy of ndarray with 0 dimention that don't return an
# ndarray.
r.data = r.data
return r
......
......@@ -186,11 +186,11 @@ def test_sort_schedule_fn():
def test_container_deepcopy():
"""
This is a test to a work around a NumPy.
This is a test to a work around a NumPy bug.
"""
t = theano.tensor.scalar()
v = numpy.asarray(0.)
c = Container(t, [v])
assert isinstance(c.storage[0], numpy.ndarray)
deepcopy(c)
assert isinstance(c.storage[0], numpy.ndarray)
d = deepcopy(c)
assert isinstance(d.storage[0], numpy.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论