提交 8f7c7e9d authored 作者: Frederic's avatar Frederic

Fix deepcopy of shared variable due to a Numpy problem.

The fix is taken from debugmode.
上级 01f54ac2
...@@ -847,6 +847,7 @@ def _lessbroken_deepcopy(a): ...@@ -847,6 +847,7 @@ def _lessbroken_deepcopy(a):
called on a 0-d array will return a numpy scalar, not an array. called on a 0-d array will return a numpy scalar, not an array.
""" """
# this exists because copy.deepcopy on numpy arrays is broken # this exists because copy.deepcopy on numpy arrays is broken
# This logic is also in link.py
if type(a) in (numpy.ndarray, numpy.memmap): if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy() rval = a.copy()
else: else:
......
"""WRITEME""" """WRITEME"""
from copy import copy from copy import copy, deepcopy
import StringIO import StringIO
import sys import sys
import traceback import traceback
import numpy
import theano import theano
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
...@@ -318,6 +320,24 @@ class Container(object): ...@@ -318,6 +320,24 @@ class Container(object):
def __repr__(self): def __repr__(self):
return "<" + repr(self.storage[0]) + ">" return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
# this exists because copy.deepcopy on numpy arrays is broken
a = self.storage[0]
if type(a) in (numpy.ndarray, numpy.memmap):
a = a.copy()
else:
a = copy.deepcopy(a)
r = type(self)(
deepcopy(self.type, memo),
[a],
deepcopy(self.readonly),
deepcopy(self.strict),
deepcopy(self.allow_downcast),
deepcopy(self.name, memo),
)
return r
def map_storage(fgraph, order, input_storage, output_storage): def map_storage(fgraph, order, input_storage, output_storage):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes. """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
......
from copy import deepcopy
import unittest import unittest
import numpy
from theano.gof import graph from theano.gof import graph
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.type import Type from theano.gof.type import Type
...@@ -9,6 +12,7 @@ from theano.gof import fg ...@@ -9,6 +12,7 @@ from theano.gof import fg
from theano.gof.link import * from theano.gof.link import *
from theano.compat import cmp from theano.compat import cmp
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
...@@ -178,3 +182,15 @@ def test_sort_schedule_fn(): ...@@ -178,3 +182,15 @@ def test_sort_schedule_fn():
for a, b in zip(nodes[:-1], nodes[1:]): for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)): if not depends((b,a)):
assert str(a) < str(b) assert str(a) < str(b)
def test_container_deepcopy():
"""
This is a test to a work around a NumPy.
"""
t = theano.tensor.scalar()
v = numpy.asarray(0.)
c = Container(t, [v])
assert isinstance(c.storage[0], numpy.ndarray)
deepcopy(c)
assert isinstance(c.storage[0], numpy.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论