提交 0e1102f1 authored 作者: Frederic's avatar Frederic

Better fix that would work for more cases if this happen again.

Also fix the test.
上级 8f7c7e9d
...@@ -4,8 +4,6 @@ import StringIO ...@@ -4,8 +4,6 @@ import StringIO
import sys import sys
import traceback import traceback
import numpy
import theano import theano
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
...@@ -321,21 +319,18 @@ class Container(object): ...@@ -321,21 +319,18 @@ class Container(object):
return "<" + repr(self.storage[0]) + ">" return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
# this exists because copy.deepcopy on numpy arrays is broken
a = self.storage[0]
if type(a) in (numpy.ndarray, numpy.memmap):
a = a.copy()
else:
a = copy.deepcopy(a)
r = type(self)( r = type(self)(
deepcopy(self.type, memo), deepcopy(self.type, memo),
[a], deepcopy(self.storage),
deepcopy(self.readonly), deepcopy(self.readonly),
deepcopy(self.strict), deepcopy(self.strict),
deepcopy(self.allow_downcast), deepcopy(self.allow_downcast),
deepcopy(self.name, memo), deepcopy(self.name, memo),
) )
# To force the call to filter. This is a work around NumPy
# deepcopy of ndarray with 0 dimention that don't return an
# ndarray.
r.data = r.data
return r return r
......
...@@ -186,11 +186,11 @@ def test_sort_schedule_fn(): ...@@ -186,11 +186,11 @@ def test_sort_schedule_fn():
def test_container_deepcopy(): def test_container_deepcopy():
""" """
This is a test to a work around a NumPy. This is a test to a work around a NumPy bug.
""" """
t = theano.tensor.scalar() t = theano.tensor.scalar()
v = numpy.asarray(0.) v = numpy.asarray(0.)
c = Container(t, [v]) c = Container(t, [v])
assert isinstance(c.storage[0], numpy.ndarray) assert isinstance(c.storage[0], numpy.ndarray)
deepcopy(c) d = deepcopy(c)
assert isinstance(c.storage[0], numpy.ndarray) assert isinstance(d.storage[0], numpy.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论