提交 6bf3539c authored 作者: abergeron's avatar abergeron

Merge pull request #1718 from nouiz/pickle_eval

Fix crash during unpickling of Theano variable
...@@ -709,6 +709,8 @@ def _pickle_Function(f): ...@@ -709,6 +709,8 @@ def _pickle_Function(f):
return rval return rval
def _constructor_Function(maker, input_storage, inputs_data): def _constructor_Function(maker, input_storage, inputs_data):
if not theano.config.unpickle_function:
return None
f = maker.create(input_storage, trustme = True) f = maker.create(input_storage, trustme = True)
assert len(f.input_storage) == len(inputs_data) assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data): for container, x in zip(f.input_storage, inputs_data):
...@@ -1204,7 +1206,10 @@ def _pickle_FunctionMaker(self): ...@@ -1204,7 +1206,10 @@ def _pickle_FunctionMaker(self):
def _constructor_FunctionMaker(kwargs): def _constructor_FunctionMaker(kwargs):
if theano.config.unpickle_function:
return FunctionMaker(**kwargs) return FunctionMaker(**kwargs)
else:
return None
copy_reg.pickle(FunctionMaker, _pickle_FunctionMaker) copy_reg.pickle(FunctionMaker, _pickle_FunctionMaker)
......
...@@ -410,6 +410,13 @@ AddConfigVar('compute_test_value_opt', ...@@ -410,6 +410,13 @@ AddConfigVar('compute_test_value_opt',
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'), EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False) in_c_key=False)
AddConfigVar('unpickle_function',
("Replace unpickled Theano function with None",
"This is useful to unpickle old graph that pickled"
" them when it shouldn't"),
BoolParam(True),
in_c_key=False)
"""Note to developers: """Note to developers:
Generally your exceptions should use an apply node's __str__ Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity method when exception_verbosity == 'low'. When exception_verbosity
......
...@@ -424,6 +424,10 @@ class Variable(Node): ...@@ -424,6 +424,10 @@ class Variable(Node):
return rval return rval
def __getstate__(self):
d = self.__dict__.copy()
d.pop("_fn", None)
return d
env = property(env_getter, env_setter, env_deleter) env = property(env_getter, env_setter, env_deleter)
......
import pickle
import unittest import unittest
from theano import tensor from theano import tensor
...@@ -292,7 +293,6 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -292,7 +293,6 @@ class TestIsSameGraph(unittest.TestCase):
debug=False) debug=False)
################ ################
# eval # # eval #
################ ################
...@@ -305,3 +305,6 @@ def test_eval(): ...@@ -305,3 +305,6 @@ def test_eval():
result = z.eval({x : 1., y : 2.}) result = z.eval({x : 1., y : 2.})
assert result == 3. assert result == 3.
# We don't want to pickle the tmp function.
assert not hasattr(pickle.loads(pickle.dumps(z)), '_fn')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论