提交 a65ca93c authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

Fixed a little error and added comments.

上级 8c44e2f3
......@@ -1077,7 +1077,10 @@ class FunctionMaker(object):
need_optimize = False
get_lock()
key = None
if theano.config.cache_optimizations:
#Beginning of cache optimizations.
#Could be refactored in different functions.
if theano.config.cache_optimizations: #set to false by default
'''
graph_db and need_optimize
'''
......@@ -1087,14 +1090,22 @@ class FunctionMaker(object):
# create graph_db
f = open(graph_db_file, 'wb')
print 'created new graph_db %s' % graph_db_file
f.close
#file needs to be open and closed for every pickle
f.close()
# load the graph_db dictionary
try:
f = open(graph_db_file, 'rb')
#Temporary hack to allow theano.scan_module.tests.test_scan.T_Scan
#to finish. Should be changed in definitive version.
tmp = theano.config.unpickle_function
theano.config.unpickle_function = False
graph_db = cPickle.load(f)
theano.config.unpickle_function = tmp
#hack end
f.close()
print 'graph_db is not empty'
except EOFError, e:
......@@ -1102,7 +1113,9 @@ class FunctionMaker(object):
print e
print 'graph_db is empty'
graph_db = {}
need_optimize = True
print 'loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db))
# the sole purpose of this loop is to set 'need_optimize'
for i, graph_old in enumerate(graph_db.keys()):
......@@ -1133,15 +1146,20 @@ class FunctionMaker(object):
elif not size_old == size_new:
print 'need to optimize, because numbers of nodes in graph are different'
continue
else:
flags = []
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'loop through outputs node for both graphs'
graph_old.variables = set(gof.graph.variables(graph_old.inputs, graph_old.outputs))
#using clone allowed to avoid a lot of errors
#deep copy seemed to had.
f2 = graph_old.clone(check_integrity=False)
t1 = output_new
t2 = f2.outputs[i]
#Used to remove "already used by another graph error
def removeAllFgraph(remove):
if hasattr(remove, 'fgraph'):
del remove.fgraph
......@@ -1160,14 +1178,22 @@ class FunctionMaker(object):
return remove
t2 = removeAllFgraph(t2)
givens = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
temp = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
gof.graph.inputs([t2])))
#hack to remove inconstent entry in givens
#seems to work that but source of inconsistency
#could be worth investigating.
for key, value in temp.iteritems():
if key.type != value.type:
del givens[key]
flag = is_same_graph(t1, t2, givens=givens)
flags.append(flag)
is_same = all(flags)
......@@ -1183,6 +1209,10 @@ class FunctionMaker(object):
# this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph'
fgraph.variables = set(gof.graph.variables(fgraph.inputs, fgraph.outputs))
#check_integrity parameters was added to ignore
#"excess cached variables" errors. Works that way
#but once again the error couldbe worth
#investigating.
before_opt = fgraph.clone(check_integrity=False)
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
......@@ -1197,10 +1227,17 @@ class FunctionMaker(object):
print 'no opt, get graph from graph_db'
# just read the optmized graph from graph_db
opt_time = 0
#"Naive" insertion. It's seems to work, but there may
#be some problems inserting it like that.
self.fgraph = graph_db[key]
fgraph = self.fgraph
# release stuff
release_lock()
#end of cache optimization
#else containing the old code
else:
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论