提交 2dc51295 authored 作者: Frederic's avatar Frederic

Fix pickling of FunctionGraph

上级 2a5c64c4
...@@ -743,3 +743,19 @@ class FunctionGraph(utils.object2): ...@@ -743,3 +743,19 @@ class FunctionGraph(utils.object2):
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
return e, equiv return e, equiv
def __getstate__(self):
"""This is needed as some feature introduce instancemethod and
this is not pickable.
"""
d = self.__dict__.copy()
for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr]
return d
def __setstate__(self, dct):
self.__dict__.update(dct)
for feature in self._features:
if hasattr(feature, "unpickle"):
feature.unpickle(self)
import pickle
import unittest import unittest
import theano import theano
from theano.gof import CachedConstantError, FunctionGraph from theano.gof import CachedConstantError, FunctionGraph
from theano import tensor as tt
class TFunctionGraph(unittest.TestCase): class TFunctionGraph(unittest.TestCase):
...@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase): ...@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase):
v = theano.tensor.constant(1) v = theano.tensor.constant(1)
assert v.cached assert v.cached
FunctionGraph([], [v + 1]) FunctionGraph([], [v + 1])
def test_pickle(self):
v = tt.vector()
func = theano.gof.FunctionGraph([v], [v + 1])
s = pickle.dumps(func)
func2 = pickle.loads(s)
...@@ -105,6 +105,7 @@ class Bookkeeper(Feature): ...@@ -105,6 +105,7 @@ class Bookkeeper(Feature):
class History(Feature): class History(Feature):
pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self): def __init__(self):
self.history = {} self.history = {}
...@@ -114,6 +115,13 @@ class History(Feature): ...@@ -114,6 +115,13 @@ class History(Feature):
raise AlreadyThere("History feature is already present or in" raise AlreadyThere("History feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
self.history[fgraph] = [] self.history[fgraph] = []
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.revert = partial(self.revert, fgraph)
def unpickle(self, fgraph):
fgraph.checkpoint = lambda: len(self.history[fgraph]) fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.revert = partial(self.revert, fgraph) fgraph.revert = partial(self.revert, fgraph)
...@@ -144,14 +152,28 @@ class History(Feature): ...@@ -144,14 +152,28 @@ class History(Feature):
class Validator(Feature): class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'): for attr in ('validate', 'validate_time'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
raise AlreadyThere("Validator feature is already present or in" raise AlreadyThere("Validator feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.validate = partial(self.validate, fgraph)
fgraph.consistent = partial(self.consistent, fgraph)
def validate(): def unpickle(self, fgraph):
fgraph.validate = partial(self.validate, fgraph)
fgraph.consistent = partial(self.consistent, fgraph)
def on_detach(self, fgraph):
del fgraph.validate
del fgraph.consistent
def validate(self, fgraph):
t0 = time.time() t0 = time.time()
ret = fgraph.execute_callbacks('validate') ret = fgraph.execute_callbacks('validate')
t1 = time.time() t1 = time.time()
...@@ -159,32 +181,37 @@ class Validator(Feature): ...@@ -159,32 +181,37 @@ class Validator(Feature):
fgraph.profile.validate_time += t1 - t0 fgraph.profile.validate_time += t1 - t0
return ret return ret
fgraph.validate = validate def consistent(self, fgraph):
def consistent():
try: try:
fgraph.validate() fgraph.validate()
return True return True
except Exception: except Exception:
return False return False
fgraph.consistent = consistent
def on_detach(self, fgraph):
del fgraph.validate
del fgraph.consistent
class ReplaceValidate(History, Validator): class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate",
"replace_all_validate_remove",
#Parent pickle_rm_attr
"consistent", "validate", "checkpoint", "revert"
]
def on_attach(self, fgraph): def on_attach(self, fgraph):
History.on_attach(self, fgraph) for attr in ('replace_validate', 'replace_all_validate',
Validator.on_attach(self, fgraph) 'replace_all_validate_remove'):
for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
raise AlreadyThere("ReplaceValidate feature is already present" raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
History.on_attach(self, fgraph)
Validator.on_attach(self, fgraph)
self.unpickle(fgraph)
def unpickle(self, fgraph):
History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph)
fgraph.replace_validate = partial(self.replace_validate, fgraph) fgraph.replace_validate = partial(self.replace_validate, fgraph)
fgraph.replace_all_validate = partial(self.replace_all_validate, fgraph) fgraph.replace_all_validate = partial(self.replace_all_validate,
fgraph)
fgraph.replace_all_validate_remove = partial( fgraph.replace_all_validate_remove = partial(
self.replace_all_validate_remove, fgraph) self.replace_all_validate_remove, fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论