提交 0d12e256 authored 作者: serdyuk's avatar serdyuk

Fixed tests

Added test pkl_utils into flake8 whitelist Refactored pkl tools and tests Removed test_pkl_utils from flake whitelist
上级 73f2057b
......@@ -13,7 +13,7 @@ import warnings
from collections import defaultdict
from contextlib import closing
from pickle import HIGHEST_PROTOCOL
from six import StringIO
from theano.compat.six import StringIO
try:
from pickle import DEFAULT_PROTOCOL
except ImportError:
......@@ -25,12 +25,11 @@ from theano.compat import PY3
from theano.compat.six import string_types
from theano.compile.sharedvalue import SharedVariable
try:
from cuda_ndarray import cuda_ndarray
from theano.sandbox.cuda import cuda_ndarray
except ImportError:
cuda_ndarray = None
__docformat__ = "restructuredtext en"
__authors__ = "Pascal Lamblin"
__copyright__ = "Copyright 2013, Universite de Montreal"
......@@ -142,22 +141,37 @@ class PersistentNdarrayID(object):
return name
def __call__(self, obj):
if ((type(obj) is numpy.ndarray) or
(type(obj) is cuda_ndarray.CudaNdarray)):
if type(obj) is numpy.ndarray:
if id(obj) not in self.seen:
def write_array(f):
numpy.lib.format.write_array(f, obj)
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)]
class PersistentCudaNdarrayID(PersistentNdarrayID):
def __init__(self, zip_file):
super(PersistentCudaNdarrayID, self).__init__(zip_file)
def __call__(self, obj):
print 'try cuda'
if (cuda_ndarray is not None and
type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray):
print 'cuda'
if id(obj) not in self.seen:
def write_array(f):
numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
if type(obj) is cuda_ndarray.CudaNdarray:
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
else:
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)]
super(PersistentCudaNdarrayID, self).__call__(obj)
class PersistentSharedVariableID(PersistentNdarrayID):
"""Persist the names of shared variable arrays in the zip file.
class PersistentSharedVariableID(PersistentCudaNdarrayID):
"""Uses shared variable names when persisting to zip file.
If a shared variable has a name, this name is used as the name of the
NPY file inside of the zip file. NumPy arrays that aren't matched to a
......@@ -232,7 +246,7 @@ class PersistentNdarrayLoad(object):
"numpy.ndarray")
return array
elif cuda_ndarray:
return cuda_ndarray.CudaNdarray(array)
return cuda_ndarray.cuda_ndarray.CudaNdarray(array)
else:
raise ImportError("Cuda not found. Cannot unpickle "
"CudaNdarray")
......@@ -240,18 +254,18 @@ class PersistentNdarrayLoad(object):
return array
def dump(obj, f, protocol=DEFAULT_PROTOCOL,
def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
persistent_id=PersistentSharedVariableID):
"""Pickles an object to a zip file using external persistence.
:param obj: The object to pickle.
:type obj: object
:param f: The file handle to save the object to.
:type f: file
:param file_handler: The file handle to save the object to.
:type file_handler: file
:param protocol: The pickling protocol to use. Unlike Python's built-in
pickle, the default is set to `2` insstead of 0 for Python 2. The
pickle, the default is set to `2` instead of 0 for Python 2. The
Python 3 default (level 3) is maintained.
:type protocol: int, optional
......@@ -283,7 +297,7 @@ def dump(obj, f, protocol=DEFAULT_PROTOCOL,
array(2)
"""
with closing(zipfile.ZipFile(f, 'w', zipfile.ZIP_DEFLATED,
with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file:
def func(f):
p = pickle.Pickler(f, protocol=protocol)
......
from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.misc.pkl_utils import dump, load
if not cuda_ndarray.cuda_available:
raise SkipTest('Optional package cuda disabled')
def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论