提交 b7c58547 authored 作者: abergeron's avatar abergeron

Merge pull request #1864 from nouiz/fix_numpy_deepcopy

Fix deepcopy of shared variable due to a Numpy problem.
......@@ -847,6 +847,7 @@ def _lessbroken_deepcopy(a):
called on a 0-d array will return a numpy scalar, not an array.
"""
# this exists because copy.deepcopy on numpy arrays is broken
# This logic is also in link.py
if type(a) in (numpy.ndarray, numpy.memmap):
rval = a.copy()
else:
......
"""WRITEME"""
from copy import copy
from copy import copy, deepcopy
import StringIO
import sys
import traceback
......@@ -318,6 +318,30 @@ class Container(object):
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray.
if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
r.storage[0] = self.type.filter(r.storage[0],
strict=False,
allow_downcast=False)
memo[id(self.storage[0])] = r.storage[0]
return r
def map_storage(fgraph, order, input_storage, output_storage):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
......
from copy import deepcopy
import unittest
import numpy
from theano.gof import graph
from theano.gof.graph import Variable, Apply, Constant
from theano.gof.type import Type
......@@ -9,6 +12,7 @@ from theano.gof import fg
from theano.gof.link import *
from theano.compat import cmp
def as_variable(x):
assert isinstance(x, Variable)
return x
......@@ -110,7 +114,8 @@ class TestPerformLinker(unittest.TestCase):
x, y, z = inputs()
a, d = add(x, y), div(x, y)
e = mul(a, d)
fn = perform_linker(FunctionGraph(*graph.clone([x, y, a], [e]))).make_function()
fn = perform_linker(FunctionGraph(*graph.clone([x, y, a],
[e]))).make_function()
assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self):
......@@ -118,7 +123,8 @@ class TestPerformLinker(unittest.TestCase):
a = add(x, y)
r = raise_err(a)
e = add(r, a)
fn = perform_linker(FunctionGraph(*graph.clone([x, y, r], [e]))).make_function()
fn = perform_linker(FunctionGraph(*graph.clone([x, y, r],
[e]))).make_function()
assert fn(1.0, 2.0, 4.5) == 7.5
......@@ -137,8 +143,8 @@ class TestWrapLinker(unittest.TestCase):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
......@@ -155,20 +161,21 @@ class TestWrapLinker(unittest.TestCase):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
def test_sort_schedule_fn():
import theano
from theano.gof.sched import sort_schedule_fn, make_depends
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x[:5]*2, x.T+1).T
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
mode = theano.Mode(linker=linker)
f = theano.function((x,), (y,), mode=mode)
......@@ -176,5 +183,18 @@ def test_sort_schedule_fn():
nodes = f.maker.linker.make_all()[-1]
depends = make_depends()
for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)):
if not depends((b, a)):
assert str(a) < str(b)
def test_container_deepcopy():
"""
This is a test to a work around a NumPy bug.
"""
t = theano.tensor.scalar()
v = numpy.asarray(0.)
for readonly in [True, False]:
c = Container(t, [v], readonly=readonly)
assert isinstance(c.storage[0], numpy.ndarray)
d = deepcopy(c)
assert isinstance(d.storage[0], numpy.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论