提交 8c44e2f3 authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

set cache optimization to ignore cached variables error.

上级 e220e10e
...@@ -1085,12 +1085,12 @@ class FunctionMaker(object): ...@@ -1085,12 +1085,12 @@ class FunctionMaker(object):
print 'graph_db exists' print 'graph_db exists'
else: else:
# create graph_db # create graph_db
f = open(graph_db_file, 'w+b') f = open(graph_db_file, 'wb')
print 'created new graph_db %s' % graph_db_file print 'created new graph_db %s' % graph_db_file
f.close f.close
# load the graph_db dictionary # load the graph_db dictionary
try: try:
f = open(graph_db_file, 'r+b') f = open(graph_db_file, 'rb')
tmp = theano.config.unpickle_function tmp = theano.config.unpickle_function
theano.config.unpickle_function = False theano.config.unpickle_function = False
graph_db = cPickle.load(f) graph_db = cPickle.load(f)
...@@ -1137,8 +1137,8 @@ class FunctionMaker(object): ...@@ -1137,8 +1137,8 @@ class FunctionMaker(object):
flags = [] flags = []
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))): for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'loop through outputs node for both graphs' print 'loop through outputs node for both graphs'
graph_old.variables = set(gof.graph.variables(graph_old.inputs, graph_old.outputs))
f2 = graph_old.clone() f2 = graph_old.clone(check_integrity=False)
t1 = output_new t1 = output_new
t2 = f2.outputs[i] t2 = f2.outputs[i]
...@@ -1182,13 +1182,14 @@ class FunctionMaker(object): ...@@ -1182,13 +1182,14 @@ class FunctionMaker(object):
if need_optimize: if need_optimize:
# this is a brand new graph, optimize it, save it to graph_db # this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph' print 'optimizing the graph'
before_opt = fgraph.clone() fgraph.variables = set(gof.graph.variables(fgraph.inputs, fgraph.outputs))
before_opt = fgraph.clone(check_integrity=False)
start_optimizer = time.time() start_optimizer = time.time()
optimizer_profile = optimizer(fgraph) optimizer_profile = optimizer(fgraph)
end_optimizer = time.time() end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer opt_time = end_optimizer - start_optimizer
graph_db.update({before_opt:fgraph}) graph_db.update({before_opt:fgraph})
f = open(graph_db_file, 'w+b') f = open(graph_db_file, 'wb')
cPickle.dump(graph_db, f, -1) cPickle.dump(graph_db, f, -1)
f.close() f.close()
print 'saved into graph_db' print 'saved into graph_db'
......
...@@ -729,16 +729,18 @@ class FunctionGraph(utils.object2): ...@@ -729,16 +729,18 @@ class FunctionGraph(utils.object2):
return self.__str__() return self.__str__()
### clone ### ### clone ###
def clone(self): def clone(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv()[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self): def clone_get_equiv(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity:
self.check_integrity() self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs])
if check_integrity:
e.check_integrity() e.check_integrity()
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论