提交 f36d29fd authored 作者: serdyuk's avatar serdyuk

Fixed CudaNdarray pickling/unpickling

上级 88c7752a
...@@ -6,22 +6,29 @@ unit tests or regression tests. ...@@ -6,22 +6,29 @@ unit tests or regression tests.
""" """
import numpy import numpy
import pickle import pickle
from six import StringIO
import sys import sys
import tempfile import tempfile
import zipfile import zipfile
import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import closing from contextlib import closing
from pickle import HIGHEST_PROTOCOL from pickle import HIGHEST_PROTOCOL
from six import StringIO
try: try:
from pickle import DEFAULT_PROTOCOL from pickle import DEFAULT_PROTOCOL
except ImportError: except ImportError:
DEFAULT_PROTOCOL = HIGHEST_PROTOCOL DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
import theano import theano
from theano import config
from theano.compat import PY3 from theano.compat import PY3
from theano.compat.six import string_types from theano.compat.six import string_types
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
try:
from cuda_ndarray import cuda_ndarray
except ImportError:
cuda_ndarray = None
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -137,13 +144,17 @@ class PersistentNdarrayID(object): ...@@ -137,13 +144,17 @@ class PersistentNdarrayID(object):
return name return name
def __call__(self, obj): def __call__(self, obj):
if type(obj) is numpy.ndarray: if ((type(obj) is numpy.ndarray) or
(type(obj) is cuda_ndarray.CudaNdarray)):
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
numpy.lib.format.write_array(f, obj) numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name) if type(obj) is cuda_ndarray.CudaNdarray:
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
else:
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
...@@ -218,7 +229,22 @@ class PersistentNdarrayLoad(object): ...@@ -218,7 +229,22 @@ class PersistentNdarrayLoad(object):
def __call__(self, persid): def __call__(self, persid):
array_type, name = persid.split('.') array_type, name = persid.split('.')
return numpy.lib.format.read_array(self.zip_file.open(name))
array = numpy.lib.format.read_array(self.zip_file.open(name))
if array_type == 'cuda_ndarray':
if config.experimental.unpickle_gpu_on_cpu:
# directly return numpy array
warnings.warn("config.experimental.unpickle_gpu_on_cpu is set "
"to True. Unpickling CudaNdarray as "
"numpy.ndarray")
return array
elif cuda_ndarray:
return cuda_ndarray.CudaNdarray(array)
else:
raise ImportError("Cuda not found. Cannot unpickle "
"CudaNdarray")
else:
return array
def dump(obj, f, protocol=DEFAULT_PROTOCOL, def dump(obj, f, protocol=DEFAULT_PROTOCOL,
......
from numpy.testing import assert_allclose
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.misc.pkl_utils import dump, load
def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False)
with open('test', 'w') as f:
dump(x, f)
with open('test', 'r') as f:
x = load(f)
assert x.name == 'x'
assert_allclose(x.get_value(), [[1]])
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论