提交 2dc51295 authored 作者: Frederic's avatar Frederic

Fix pickling of FunctionGraph

上级 2a5c64c4
......@@ -743,3 +743,19 @@ class FunctionGraph(utils.object2):
for feature in self._features:
e.attach_feature(feature)
return e, equiv
def __getstate__(self):
"""This is needed as some feature introduce instancemethod and
this is not pickable.
"""
d = self.__dict__.copy()
for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr]
return d
def __setstate__(self, dct):
self.__dict__.update(dct)
for feature in self._features:
if hasattr(feature, "unpickle"):
feature.unpickle(self)
import pickle
import unittest
import theano
from theano.gof import CachedConstantError, FunctionGraph
from theano import tensor as tt
class TFunctionGraph(unittest.TestCase):
......@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
def test_pickle(self):
v = tt.vector()
func = theano.gof.FunctionGraph([v], [v + 1])
s = pickle.dumps(func)
func2 = pickle.loads(s)
......@@ -105,6 +105,7 @@ class Bookkeeper(Feature):
class History(Feature):
pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self):
self.history = {}
......@@ -114,6 +115,13 @@ class History(Feature):
raise AlreadyThere("History feature is already present or in"
" conflict with another plugin.")
self.history[fgraph] = []
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.revert = partial(self.revert, fgraph)
def unpickle(self, fgraph):
fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.revert = partial(self.revert, fgraph)
......@@ -144,47 +152,66 @@ class History(Feature):
class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]
def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'):
if hasattr(fgraph, attr):
raise AlreadyThere("Validator feature is already present or in"
" conflict with another plugin.")
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.validate = partial(self.validate, fgraph)
fgraph.consistent = partial(self.consistent, fgraph)
def validate():
t0 = time.time()
ret = fgraph.execute_callbacks('validate')
t1 = time.time()
if fgraph.profile:
fgraph.profile.validate_time += t1 - t0
return ret
fgraph.validate = validate
def consistent():
try:
fgraph.validate()
return True
except Exception:
return False
fgraph.consistent = consistent
def unpickle(self, fgraph):
fgraph.validate = partial(self.validate, fgraph)
fgraph.consistent = partial(self.consistent, fgraph)
def on_detach(self, fgraph):
del fgraph.validate
del fgraph.consistent
def validate(self, fgraph):
t0 = time.time()
ret = fgraph.execute_callbacks('validate')
t1 = time.time()
if fgraph.profile:
fgraph.profile.validate_time += t1 - t0
return ret
def consistent(self, fgraph):
try:
fgraph.validate()
return True
except Exception:
return False
class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate",
"replace_all_validate_remove",
#Parent pickle_rm_attr
"consistent", "validate", "checkpoint", "revert"
]
def on_attach(self, fgraph):
History.on_attach(self, fgraph)
Validator.on_attach(self, fgraph)
for attr in ('replace_validate', 'replace_all_validate'):
for attr in ('replace_validate', 'replace_all_validate',
'replace_all_validate_remove'):
if hasattr(fgraph, attr):
raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.")
History.on_attach(self, fgraph)
Validator.on_attach(self, fgraph)
self.unpickle(fgraph)
def unpickle(self, fgraph):
History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph)
fgraph.replace_validate = partial(self.replace_validate, fgraph)
fgraph.replace_all_validate = partial(self.replace_all_validate, fgraph)
fgraph.replace_all_validate = partial(self.replace_all_validate,
fgraph)
fgraph.replace_all_validate_remove = partial(
self.replace_all_validate_remove, fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论