提交 bad34f33 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/misc/pkl_utils.py

上级 fc8e7324
...@@ -5,7 +5,7 @@ These pickled graphs can be used, for instance, as cases for ...@@ -5,7 +5,7 @@ These pickled graphs can be used, for instance, as cases for
unit tests or regression tests. unit tests or regression tests.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy import numpy as np
import os import os
import pickle import pickle
import sys import sys
...@@ -188,10 +188,10 @@ class PersistentNdarrayID(object): ...@@ -188,10 +188,10 @@ class PersistentNdarrayID(object):
return name return name
def __call__(self, obj): def __call__(self, obj):
if type(obj) is numpy.ndarray: if type(obj) is np.ndarray:
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
numpy.lib.format.write_array(f, obj) np.lib.format.write_array(f, obj)
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'ndarray.{0}'.format(name) self.seen[id(obj)] = 'ndarray.{0}'.format(name)
...@@ -204,7 +204,7 @@ class PersistentCudaNdarrayID(PersistentNdarrayID): ...@@ -204,7 +204,7 @@ class PersistentCudaNdarrayID(PersistentNdarrayID):
type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray): type(obj) is cuda_ndarray.cuda_ndarray.CudaNdarray):
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
numpy.lib.format.write_array(f, numpy.asarray(obj)) np.lib.format.write_array(f, np.asarray(obj))
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = 'cuda_ndarray.{0}'.format(name) self.seen[id(obj)] = 'cuda_ndarray.{0}'.format(name)
...@@ -283,7 +283,7 @@ class PersistentNdarrayLoad(object): ...@@ -283,7 +283,7 @@ class PersistentNdarrayLoad(object):
if name in self.cache: if name in self.cache:
return self.cache[name] return self.cache[name]
ret = None ret = None
array = numpy.lib.format.read_array(self.zip_file.open(name)) array = np.lib.format.read_array(self.zip_file.open(name))
if array_type == 'cuda_ndarray': if array_type == 'cuda_ndarray':
if config.experimental.unpickle_gpu_on_cpu: if config.experimental.unpickle_gpu_on_cpu:
# directly return numpy array # directly return numpy array
...@@ -335,10 +335,10 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL, ...@@ -335,10 +335,10 @@ def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
>>> foo_1 = theano.shared(0, name='foo') >>> foo_1 = theano.shared(0, name='foo')
>>> foo_2 = theano.shared(1, name='foo') >>> foo_2 = theano.shared(1, name='foo')
>>> with open('model.zip', 'wb') as f: >>> with open('model.zip', 'wb') as f:
... dump((foo_1, foo_2, numpy.array(2)), f) ... dump((foo_1, foo_2, np.array(2)), f)
>>> numpy.load('model.zip').keys() >>> np.load('model.zip').keys()
['foo', 'foo_2', 'array_0', 'pkl'] ['foo', 'foo_2', 'array_0', 'pkl']
>>> numpy.load('model.zip')['foo'] >>> np.load('model.zip')['foo']
array(0) array(0)
>>> with open('model.zip', 'rb') as f: >>> with open('model.zip', 'rb') as f:
... foo_1, foo_2, array = load(f) ... foo_1, foo_2, array = load(f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论