提交 f36d29fd authored 作者: serdyuk's avatar serdyuk

Fixed CudaNdarray pickling/unpickling

上级 88c7752a
......@@ -6,22 +6,29 @@ unit tests or regression tests.
"""
import numpy
import pickle
from six import StringIO
import sys
import tempfile
import zipfile
import warnings
from collections import defaultdict
from contextlib import closing
from pickle import HIGHEST_PROTOCOL
from six import StringIO
try:
from pickle import DEFAULT_PROTOCOL
except ImportError:
DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
import theano
from theano import config
from theano.compat import PY3
from theano.compat.six import string_types
from theano.compile.sharedvalue import SharedVariable
try:
from cuda_ndarray import cuda_ndarray
except ImportError:
cuda_ndarray = None
__docformat__ = "restructuredtext en"
......@@ -137,13 +144,17 @@ class PersistentNdarrayID(object):
return name
def __call__(self, obj):
if type(obj) is numpy.ndarray:
if ((type(obj) is numpy.ndarray) or
(type(obj) is cuda_ndarray.CudaNdarray)):
if id(obj) not in self.seen:
def write_array(f):
numpy.lib.format.write_array(f, obj)
numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name)
if type(obj) is cuda_ndarray.CudaNdarray:
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
else:
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)]
......@@ -218,7 +229,22 @@ class PersistentNdarrayLoad(object):
def __call__(self, persid):
array_type, name = persid.split('.')
return numpy.lib.format.read_array(self.zip_file.open(name))
array = numpy.lib.format.read_array(self.zip_file.open(name))
if array_type == 'cuda_ndarray':
if config.experimental.unpickle_gpu_on_cpu:
# directly return numpy array
warnings.warn("config.experimental.unpickle_gpu_on_cpu is set "
"to True. Unpickling CudaNdarray as "
"numpy.ndarray")
return array
elif cuda_ndarray:
return cuda_ndarray.CudaNdarray(array)
else:
raise ImportError("Cuda not found. Cannot unpickle "
"CudaNdarray")
else:
return array
def dump(obj, f, protocol=DEFAULT_PROTOCOL,
......
from numpy.testing import assert_allclose
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.misc.pkl_utils import dump, load
def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
[[1]], False)
with open('test', 'w') as f:
dump(x, f)
with open('test', 'r') as f:
x = load(f)
assert x.name == 'x'
assert_allclose(x.get_value(), [[1]])
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论