提交 3652d765 authored 作者: Markus Roth's avatar Markus Roth

Unpickle cudaNdarray as numpy.ndarray, if cuda is not enabled.

However, emit a warning.
上级 99295012
ctheano.sandbox.cuda.type
CudaNdarray_unpickler
p1
(cnumpy.core.multiarray
_reconstruct
p2
(cnumpy
ndarray
p3
(I0
tS'b'
tRp4
(I1
(I1
tcnumpy
dtype
p5
(S'f4'
I0
I1
tRp6
(I3
S'<'
NNNI-1
I-1
I0
tbI00
S'\x00\x00(\xc2'
tbtR.
\ No newline at end of file
import unittest import unittest
import cPickle
import numpy import numpy
from numpy.testing.decorators import skipif from numpy.testing.decorators import skipif
import os.path
import theano import theano
from theano import tensor from theano import tensor
...@@ -10,6 +12,9 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available ...@@ -10,6 +12,9 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
if cuda_available:
from theano.sandbox.cuda import CudaNdarray
@skipif(not cuda_available, msg='Optional package cuda disabled') @skipif(not cuda_available, msg='Optional package cuda disabled')
def test_float32_shared_constructor(): def test_float32_shared_constructor():
npy_row = numpy.zeros((1, 10), dtype='float32') npy_row = numpy.zeros((1, 10), dtype='float32')
...@@ -121,3 +126,20 @@ class T_updates(unittest.TestCase): ...@@ -121,3 +126,20 @@ class T_updates(unittest.TestCase):
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates=[(output_var, up)]) updates=[(output_var, up)])
output_func() output_func()
def test_unpickle_cudandarray_as_numpy_ndarray():
# testfile created on cuda enabled machine using
# >>> with open('CudaNdarray.pkl', 'wb') as fp:
# >>> cPickle.dump(theano.sandbox.cuda.CudaNdarray(np.array([-42.0], dtype=np.float32)), fp)
testfile_dir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(testfile_dir, 'CudaNdarray.pkl')) as fp:
mat = cPickle.load(fp)
if cuda_available:
assert isinstance(mat, CudaNdarray)
else:
assert isinstance(mat, numpy.ndarray)
assert mat[0] == -42.0
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
""" """
import os import os
import copy_reg import copy_reg
import warnings
import numpy import numpy
...@@ -487,7 +488,13 @@ theano.compile.register_deep_copy_op_c_code( ...@@ -487,7 +488,13 @@ theano.compile.register_deep_copy_op_c_code(
# equal the pickled version, and the cmodule cache is not happy with # equal the pickled version, and the cmodule cache is not happy with
# the situation. # the situation.
def CudaNdarray_unpickler(npa): def CudaNdarray_unpickler(npa):
return cuda.CudaNdarray(npa) if cuda:
return cuda.CudaNdarray(npa)
else:
# directly return numpy array
warnings.warn("CUDA not found. Unpickling CudaNdarray as numpy.ndarray")
return npa
copy_reg.constructor(CudaNdarray_unpickler) copy_reg.constructor(CudaNdarray_unpickler)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论