提交 4f3aef6e authored 作者: Frederic's avatar Frederic

Make function_dump strip blocks stuff

上级 edfb49e4
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Define the `function` function. Define the `function` function.
""" """
import six.moves.cPickle as pickle
import logging import logging
import traceback as tb import traceback as tb
...@@ -58,7 +57,9 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -58,7 +57,9 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
allow_input_downcast=allow_input_downcast, profile=profile, allow_input_downcast=allow_input_downcast, profile=profile,
on_unused_input=on_unused_input) on_unused_input=on_unused_input)
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
pickle.dump(d, f, -1) import theano.misc.pkl_utils
pickler = theano.misc.pkl_utils.StripPickler(f, protocol=-1)
pickler.dump(d)
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
......
...@@ -60,7 +60,17 @@ class StripPickler(Pickler): ...@@ -60,7 +60,17 @@ class StripPickler(Pickler):
if isinstance(obj, theano.gof.utils.scratchpad): if isinstance(obj, theano.gof.utils.scratchpad):
if hasattr(obj, 'trace'): if hasattr(obj, 'trace'):
del obj.trace del obj.trace
if hasattr(obj, 'test_value'):
del obj.test_value
# The next 4 items are from Blocks
if hasattr(obj, 'annontations'):
del obj.annontations
if hasattr(obj, 'replacement_of'):
del obj.replacement_of
if hasattr(obj, 'aggregation_scheme'):
del obj.aggregation_scheme
if hasattr(obj, 'rolesc'):
del obj.rolesc
# Remove manually-added docstring of Elemwise ops # Remove manually-added docstring of Elemwise ops
elif (isinstance(obj, theano.tensor.Elemwise)): elif (isinstance(obj, theano.tensor.Elemwise)):
if '__doc__' in obj.__dict__: if '__doc__' in obj.__dict__:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论