提交 ddb24d79 authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

Using temporary pickle to debug. Refactored lambda expressions used in history.…

Using temporary pickle to debug. Refactored lambda expressions used in history. Refactored problematic Opt.
上级 0e773298
......@@ -1079,9 +1079,9 @@ class FunctionMaker(object):
opt_time: timing
'''
from theano.gof.compilelock import get_lock, release_lock
import cPickle
import pickle
import os.path
graph_db_file = theano.config.compiledir + '/optimized_graphs.pkl'
graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl')
# the inputs, outputs, and size of the graph to be optimized
inputs_new = fgraph.inputs
outputs_new = fgraph.outputs
......@@ -1101,10 +1101,11 @@ class FunctionMaker(object):
# load the graph_db dictionary
try:
graph_db = cPickle.load(f)
graph_db = pickle.load(f)
print 'graph_db is not empty'
except EOFError:
except EOFError, e:
# the file has nothing in it
print e
print 'graph_db is empty'
graph_db = {}
......@@ -1180,7 +1181,7 @@ class FunctionMaker(object):
cPickle.load(test_file)
'''
graph_db.update({before_opt:fgraph})
cPickle.dump(graph_db, f, -1)
pickle.dump(graph_db, f, -1)
print 'saved into graph_db'
else:
print 'no opt, get graph from graph_db'
......
......@@ -104,6 +104,30 @@ class Bookkeeper(Feature):
self.on_prune(fgraph, node, 'Bookkeeper.detach')
class getCheckpoint:
def __init__(self, history, fgraph):
self.h = history
self.fgraph = fgraph
def __call__(self):
return len(self.h.history[self.fgraph])
class lambdextract:
def __init__(self, fgraph, node, i, r, reason=None):
self.fgraph = fgraph
self.node = node
self.i = i
self.r = r
self.reason = reason
def __call__(self):
return self.fgraph.change_input(self.node, self.i, self.r,
reason=("Revert", self.reason))
class History(Feature):
pickle_rm_attr = ["checkpoint", "revert"]
......@@ -118,11 +142,11 @@ class History(Feature):
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.checkpoint = getCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
def unpickle(self, fgraph):
fgraph.checkpoint = lambda: len(self.history[fgraph])
fgraph.checkpoint = getCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
def on_detach(self, fgraph):
......@@ -134,8 +158,7 @@ class History(Feature):
if self.history[fgraph] is None:
return
h = self.history[fgraph]
h.append(lambda: fgraph.change_input(node, i, r,
reason=("Revert", reason)))
h.append(lambdextract(fgraph, node, i, r, reason))
def revert(self, fgraph, checkpoint):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论