提交 8c44e2f3 authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

set cache optimization to ignore cached variables error.

上级 e220e10e
......@@ -1085,12 +1085,12 @@ class FunctionMaker(object):
print 'graph_db exists'
else:
# create graph_db
f = open(graph_db_file, 'w+b')
f = open(graph_db_file, 'wb')
print 'created new graph_db %s' % graph_db_file
f.close
# load the graph_db dictionary
try:
f = open(graph_db_file, 'r+b')
f = open(graph_db_file, 'rb')
tmp = theano.config.unpickle_function
theano.config.unpickle_function = False
graph_db = cPickle.load(f)
......@@ -1137,8 +1137,8 @@ class FunctionMaker(object):
flags = []
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'loop through outputs node for both graphs'
f2 = graph_old.clone()
graph_old.variables = set(gof.graph.variables(graph_old.inputs, graph_old.outputs))
f2 = graph_old.clone(check_integrity=False)
t1 = output_new
t2 = f2.outputs[i]
......@@ -1182,13 +1182,14 @@ class FunctionMaker(object):
if need_optimize:
# this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph'
before_opt = fgraph.clone()
fgraph.variables = set(gof.graph.variables(fgraph.inputs, fgraph.outputs))
before_opt = fgraph.clone(check_integrity=False)
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
graph_db.update({before_opt:fgraph})
f = open(graph_db_file, 'w+b')
f = open(graph_db_file, 'wb')
cPickle.dump(graph_db, f, -1)
f.close()
print 'saved into graph_db'
......
......@@ -729,16 +729,18 @@ class FunctionGraph(utils.object2):
return self.__str__()
### clone ###
def clone(self):
def clone(self, check_integrity=True):
"""WRITEME"""
return self.clone_get_equiv()[0]
return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self):
def clone_get_equiv(self, check_integrity=True):
"""WRITEME"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity:
self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs])
if check_integrity:
e.check_integrity()
for feature in self._features:
e.attach_feature(feature)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论