提交 55bd12ce authored 作者: Frederic's avatar Frederic

Fix some pickling issues

上级 7c70a9ba
...@@ -752,6 +752,8 @@ class FunctionGraph(utils.object2): ...@@ -752,6 +752,8 @@ class FunctionGraph(utils.object2):
for feature in self._features: for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []): for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr] del d[attr]
# The class Updater take fct as parameter and they are lambda function, so unpicklable.
# del d["execute_callbacks_times"]
return d return d
def __setstate__(self, dct): def __setstate__(self, dct):
......
...@@ -1241,6 +1241,30 @@ class PatternSub(LocalOptimizer): ...@@ -1241,6 +1241,30 @@ class PatternSub(LocalOptimizer):
# Use the following classes to apply LocalOptimizers # Use the following classes to apply LocalOptimizers
class Updater:
def __init__(self, importer, pruner, chin):
self.importer = importer
self.pruner = pruner
self.chin = chin
def on_import(self, fgraph, node, reason):
if self.importer:
self.importer(node)
def on_prune(self, fgraph, node, reason):
if self.pruner:
self.pruner(node)
def on_change_input(self, fgraph, node, i, r, new_r, reason):
if self.chin:
self.chin(node, i, r, new_r, reason)
def on_detach(self, fgraph):
# To allow pickling this object
self.importer = None
self.pruner = None
self.chin = None
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""Abstract class """Abstract class
...@@ -1329,18 +1353,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1329,18 +1353,7 @@ class NavigatorOptimizer(Optimizer):
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
class Updater: u = Updater(importer, pruner, chin)
if importer is not None:
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r, reason):
chin(node, i, r, new_r, reason)
u = Updater()
fgraph.attach_feature(u) fgraph.attach_feature(u)
return u return u
......
...@@ -1494,11 +1494,11 @@ class GemmOptimizer(Optimizer): ...@@ -1494,11 +1494,11 @@ class GemmOptimizer(Optimizer):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
class Updater: def on_import(self, new_node):
def on_import(self, fgraph, new_node, reason):
if new_node is not node: if new_node is not node:
nodelist.append(new_node) nodelist.append(new_node)
u = Updater()
u = theano.gof.opt.Updater(on_import, None, None)
fgraph.attach_feature(u) fgraph.attach_feature(u)
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论