提交 7b4507cd authored 作者: Frederic's avatar Frederic

Make a list of extra tag to remove and add this parameter to function_dump. fix gh-3381.

上级 db344856
......@@ -23,7 +23,8 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input=None):
on_unused_input=None,
extra_tag_to_remove=None):
"""
This is helpful to make a reproducable case for problem during Theano
compilation.
......@@ -48,6 +49,11 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
>>> d = cPickle.load(open("func_dump.bin", "rb")) # doctest: +SKIP
>>> f = theano.function(**d) # doctest: +SKIP
Note:
The parameter extra_tag_to_remove, is passed to the StripPickler used.
To pickle graph made by Blocks, it must be:
['annotations', 'replacement_of', 'aggregation_scheme', 'rolesc']
"""
assert isinstance(filename, string_types)
d = dict(inputs=inputs, outputs=outputs, mode=mode, updates=updates,
......@@ -58,7 +64,9 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
on_unused_input=on_unused_input)
with open(filename, 'wb') as f:
import theano.misc.pkl_utils
pickler = theano.misc.pkl_utils.StripPickler(f, protocol=-1)
pickler = theano.misc.pkl_utils.StripPickler(
f, protocol=-1,
extra_tag_to_remove=extra_tag_to_remove)
pickler.dump(d)
......
......@@ -55,22 +55,19 @@ class StripPickler(Pickler):
strip_pickler.dump(fn_args)
f.close()
"""
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
# Can't use super as Pickler isn't a new style class
Pickler.__init__(self, file, protocol)
self.tag_to_remove = ['trace', 'test_value']
if extra_tag_to_remove:
self.tag_to_remove.extend(extra_tag_to_remove)
def save(self, obj):
# Remove the tag.trace attribute from Variable and Apply nodes
if isinstance(obj, theano.gof.utils.scratchpad):
if hasattr(obj, 'trace'):
del obj.trace
if hasattr(obj, 'test_value'):
del obj.test_value
# The next 4 items are from Blocks
if hasattr(obj, 'annotations'):
del obj.annontations
if hasattr(obj, 'replacement_of'):
del obj.replacement_of
if hasattr(obj, 'aggregation_scheme'):
del obj.aggregation_scheme
if hasattr(obj, 'rolesc'):
del obj.rolesc
for tag in self.tag_to_remove:
if hasattr(obj, tag):
del obj.__dict__[tag]
# Remove manually-added docstring of Elemwise ops
elif (isinstance(obj, theano.tensor.Elemwise)):
if '__doc__' in obj.__dict__:
......
......@@ -13,7 +13,7 @@ import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import CudaNdarraySharedVariable
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.misc.pkl_utils import dump, load
from theano.misc.pkl_utils import dump, load, StripPickler
class T_dump_load(unittest.TestCase):
......@@ -69,3 +69,25 @@ class T_dump_load(unittest.TestCase):
with open('model.zip', 'rb') as f:
foo_1, foo_2, foo_3, array = load(f)
assert array == numpy.array(3)
class TestStripPickler(unittest.TestCase):
def setUp(self):
# Work in a temporary directory to avoid cluttering the repository
self.origdir = os.getcwd()
self.tmpdir = mkdtemp()
os.chdir(self.tmpdir)
def tearDown(self):
# Get back to the original dir, and delete the temporary one
os.chdir(self.origdir)
if self.tmpdir is not None:
shutil.rmtree(self.tmpdir)
def test0(self):
with open('test.pkl', 'wb') as f:
m = theano.tensor.matrix()
dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb')
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论