提交 7b4507cd authored 作者: Frederic's avatar Frederic

Make a list of extra tag to remove and add this parameter to function_dump. fix gh-3381.

上级 db344856
...@@ -23,7 +23,8 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -23,7 +23,8 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
givens=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input=None): on_unused_input=None,
extra_tag_to_remove=None):
""" """
This is helpful to make a reproducable case for problem during Theano This is helpful to make a reproducable case for problem during Theano
compilation. compilation.
...@@ -48,6 +49,11 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -48,6 +49,11 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
>>> d = cPickle.load(open("func_dump.bin", "rb")) # doctest: +SKIP >>> d = cPickle.load(open("func_dump.bin", "rb")) # doctest: +SKIP
>>> f = theano.function(**d) # doctest: +SKIP >>> f = theano.function(**d) # doctest: +SKIP
Note:
The parameter extra_tag_to_remove, is passed to the StripPickler used.
To pickle graph made by Blocks, it must be:
['annotations', 'replacement_of', 'aggregation_scheme', 'rolesc']
""" """
assert isinstance(filename, string_types) assert isinstance(filename, string_types)
d = dict(inputs=inputs, outputs=outputs, mode=mode, updates=updates, d = dict(inputs=inputs, outputs=outputs, mode=mode, updates=updates,
...@@ -58,7 +64,9 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -58,7 +64,9 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
on_unused_input=on_unused_input) on_unused_input=on_unused_input)
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
import theano.misc.pkl_utils import theano.misc.pkl_utils
pickler = theano.misc.pkl_utils.StripPickler(f, protocol=-1) pickler = theano.misc.pkl_utils.StripPickler(
f, protocol=-1,
extra_tag_to_remove=extra_tag_to_remove)
pickler.dump(d) pickler.dump(d)
......
...@@ -55,22 +55,19 @@ class StripPickler(Pickler): ...@@ -55,22 +55,19 @@ class StripPickler(Pickler):
strip_pickler.dump(fn_args) strip_pickler.dump(fn_args)
f.close() f.close()
""" """
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
# Can't use super as Pickler isn't a new style class
Pickler.__init__(self, file, protocol)
self.tag_to_remove = ['trace', 'test_value']
if extra_tag_to_remove:
self.tag_to_remove.extend(extra_tag_to_remove)
def save(self, obj): def save(self, obj):
# Remove the tag.trace attribute from Variable and Apply nodes # Remove the tag.trace attribute from Variable and Apply nodes
if isinstance(obj, theano.gof.utils.scratchpad): if isinstance(obj, theano.gof.utils.scratchpad):
if hasattr(obj, 'trace'): for tag in self.tag_to_remove:
del obj.trace if hasattr(obj, tag):
if hasattr(obj, 'test_value'): del obj.__dict__[tag]
del obj.test_value
# The next 4 items are from Blocks
if hasattr(obj, 'annotations'):
del obj.annontations
if hasattr(obj, 'replacement_of'):
del obj.replacement_of
if hasattr(obj, 'aggregation_scheme'):
del obj.aggregation_scheme
if hasattr(obj, 'rolesc'):
del obj.rolesc
# Remove manually-added docstring of Elemwise ops # Remove manually-added docstring of Elemwise ops
elif (isinstance(obj, theano.tensor.Elemwise)): elif (isinstance(obj, theano.tensor.Elemwise)):
if '__doc__' in obj.__dict__: if '__doc__' in obj.__dict__:
......
...@@ -13,7 +13,7 @@ import theano.sandbox.cuda as cuda_ndarray ...@@ -13,7 +13,7 @@ import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load from theano.misc.pkl_utils import dump, load, StripPickler
class T_dump_load(unittest.TestCase): class T_dump_load(unittest.TestCase):
...@@ -69,3 +69,25 @@ class T_dump_load(unittest.TestCase): ...@@ -69,3 +69,25 @@ class T_dump_load(unittest.TestCase):
with open('model.zip', 'rb') as f: with open('model.zip', 'rb') as f:
foo_1, foo_2, foo_3, array = load(f) foo_1, foo_2, foo_3, array = load(f)
assert array == numpy.array(3) assert array == numpy.array(3)
class TestStripPickler(unittest.TestCase):
def setUp(self):
# Work in a temporary directory to avoid cluttering the repository
self.origdir = os.getcwd()
self.tmpdir = mkdtemp()
os.chdir(self.tmpdir)
def tearDown(self):
# Get back to the original dir, and delete the temporary one
os.chdir(self.origdir)
if self.tmpdir is not None:
shutil.rmtree(self.tmpdir)
def test0(self):
with open('test.pkl', 'wb') as f:
m = theano.tensor.matrix()
dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb')
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论