提交 3652d765 authored 作者: Markus Roth's avatar Markus Roth

Unpickle cudaNdarray as numpy.ndarray, if cuda is not enabled.

However, emit a warning.
上级 99295012
ctheano.sandbox.cuda.type
CudaNdarray_unpickler
p1
(cnumpy.core.multiarray
_reconstruct
p2
(cnumpy
ndarray
p3
(I0
tS'b'
tRp4
(I1
(I1
tcnumpy
dtype
p5
(S'f4'
I0
I1
tRp6
(I3
S'<'
NNNI-1
I-1
I0
tbI00
S'\x00\x00(\xc2'
tbtR.
\ No newline at end of file
import unittest
import cPickle
import numpy
from numpy.testing.decorators import skipif
import os.path
import theano
from theano import tensor
......@@ -10,6 +12,9 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available
import theano.sandbox.cuda as cuda
if cuda_available:
from theano.sandbox.cuda import CudaNdarray
@skipif(not cuda_available, msg='Optional package cuda disabled')
def test_float32_shared_constructor():
npy_row = numpy.zeros((1, 10), dtype='float32')
......@@ -121,3 +126,20 @@ class T_updates(unittest.TestCase):
output_func = theano.function(inputs=[], outputs=[],
updates=[(output_var, up)])
output_func()
def test_unpickle_cudandarray_as_numpy_ndarray():
# testfile created on cuda enabled machine using
# >>> with open('CudaNdarray.pkl', 'wb') as fp:
# >>> cPickle.dump(theano.sandbox.cuda.CudaNdarray(np.array([-42.0], dtype=np.float32)), fp)
testfile_dir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(testfile_dir, 'CudaNdarray.pkl')) as fp:
mat = cPickle.load(fp)
if cuda_available:
assert isinstance(mat, CudaNdarray)
else:
assert isinstance(mat, numpy.ndarray)
assert mat[0] == -42.0
\ No newline at end of file
......@@ -2,6 +2,7 @@
"""
import os
import copy_reg
import warnings
import numpy
......@@ -487,7 +488,13 @@ theano.compile.register_deep_copy_op_c_code(
# equal the pickled version, and the cmodule cache is not happy with
# the situation.
def CudaNdarray_unpickler(npa):
return cuda.CudaNdarray(npa)
if cuda:
return cuda.CudaNdarray(npa)
else:
# directly return numpy array
warnings.warn("CUDA not found. Unpickling CudaNdarray as numpy.ndarray")
return npa
copy_reg.constructor(CudaNdarray_unpickler)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论