提交 0d12e256 authored 作者: serdyuk's avatar serdyuk

Fixed tests

Added test pkl_utils into flake8 whitelist Refactored pkl tools and tests Removed test_pkl_utils from flake whitelist
上级 73f2057b
...@@ -13,7 +13,7 @@ import warnings ...@@ -13,7 +13,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import closing from contextlib import closing
from pickle import HIGHEST_PROTOCOL from pickle import HIGHEST_PROTOCOL
from six import StringIO from theano.compat.six import StringIO
try: try:
from pickle import DEFAULT_PROTOCOL from pickle import DEFAULT_PROTOCOL
except ImportError: except ImportError:
...@@ -25,12 +25,11 @@ from theano.compat import PY3 ...@@ -25,12 +25,11 @@ from theano.compat import PY3
from theano.compat.six import string_types from theano.compat.six import string_types
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
try: try:
from cuda_ndarray import cuda_ndarray from theano.sandbox.cuda import cuda_ndarray
except ImportError: except ImportError:
cuda_ndarray = None cuda_ndarray = None
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
__authors__ = "Pascal Lamblin" __authors__ = "Pascal Lamblin"
__copyright__ = "Copyright 2013, Universite de Montreal" __copyright__ = "Copyright 2013, Universite de Montreal"
...@@ -142,22 +141,37 @@ class PersistentNdarrayID(object): ...@@ -142,22 +141,37 @@ class PersistentNdarrayID(object):
return name return name
def __call__(self, obj): def __call__(self, obj):
if ((type(obj) is numpy.ndarray) or if type(obj) is numpy.ndarray:
(type(obj) is cuda_ndarray.CudaNdarray)): if id(obj) not in self.seen:
def write_array(f):
numpy.lib.format.write_array(f, obj)
name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)]
class PersistentCudaNdarrayID(PersistentNdarrayID):
def __init__(self, zip_file):
super(PersistentCudaNdarrayID, self).__init__(zip_file)
def __call__(self, obj):
print 'try cuda'
if (cuda_ndarray is not None and
type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray):
print 'cuda'
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
numpy.lib.format.write_array(f, numpy.asarray(obj)) numpy.lib.format.write_array(f, numpy.asarray(obj))
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
if type(obj) is cuda_ndarray.CudaNdarray: self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
self.seen[id(obj)] = 'cuda_ndarray.{}'.format(name)
else:
self.seen[id(obj)] = 'ndarray.{}'.format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
super(PersistentCudaNdarrayID, self).__call__(obj)
class PersistentSharedVariableID(PersistentNdarrayID): class PersistentSharedVariableID(PersistentCudaNdarrayID):
"""Persist the names of shared variable arrays in the zip file. """Uses shared variable names when persisting to zip file.
If a shared variable has a name, this name is used as the name of the If a shared variable has a name, this name is used as the name of the
NPY file inside of the zip file. NumPy arrays that aren't matched to a NPY file inside of the zip file. NumPy arrays that aren't matched to a
...@@ -232,7 +246,7 @@ class PersistentNdarrayLoad(object): ...@@ -232,7 +246,7 @@ class PersistentNdarrayLoad(object):
"numpy.ndarray") "numpy.ndarray")
return array return array
elif cuda_ndarray: elif cuda_ndarray:
return cuda_ndarray.CudaNdarray(array) return cuda_ndarray.cuda_ndarray.CudaNdarray(array)
else: else:
raise ImportError("Cuda not found. Cannot unpickle " raise ImportError("Cuda not found. Cannot unpickle "
"CudaNdarray") "CudaNdarray")
...@@ -240,18 +254,18 @@ class PersistentNdarrayLoad(object): ...@@ -240,18 +254,18 @@ class PersistentNdarrayLoad(object):
return array return array
def dump(obj, f, protocol=DEFAULT_PROTOCOL, def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
persistent_id=PersistentSharedVariableID): persistent_id=PersistentSharedVariableID):
"""Pickles an object to a zip file using external persistence. """Pickles an object to a zip file using external persistence.
:param obj: The object to pickle. :param obj: The object to pickle.
:type obj: object :type obj: object
:param f: The file handle to save the object to. :param file_handler: The file handle to save the object to.
:type f: file :type file_handler: file
:param protocol: The pickling protocol to use. Unlike Python's built-in :param protocol: The pickling protocol to use. Unlike Python's built-in
pickle, the default is set to `2` insstead of 0 for Python 2. The pickle, the default is set to `2` instead of 0 for Python 2. The
Python 3 default (level 3) is maintained. Python 3 default (level 3) is maintained.
:type protocol: int, optional :type protocol: int, optional
...@@ -283,7 +297,7 @@ def dump(obj, f, protocol=DEFAULT_PROTOCOL, ...@@ -283,7 +297,7 @@ def dump(obj, f, protocol=DEFAULT_PROTOCOL,
array(2) array(2)
""" """
with closing(zipfile.ZipFile(f, 'w', zipfile.ZIP_DEFLATED, with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
allowZip64=True)) as zip_file: allowZip64=True)) as zip_file:
def func(f): def func(f):
p = pickle.Pickler(f, protocol=protocol) p = pickle.Pickler(f, protocol=protocol)
......
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load
if not cuda_ndarray.cuda_available:
raise SkipTest('Optional package cuda disabled')
def test_dump_load(): def test_dump_load():
x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'), x = CudaNdarraySharedVariable('x', CudaNdarrayType((1, 1), name='x'),
...@@ -16,4 +22,4 @@ def test_dump_load(): ...@@ -16,4 +22,4 @@ def test_dump_load():
x = load(f) x = load(f)
assert x.name == 'x' assert x.name == 'x'
assert_allclose(x.get_value(), [[1]]) assert_allclose(x.get_value(), [[1]])
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论