提交 d0c93bbe authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

Added flag, defaulted to False.

上级 99157a6d
...@@ -1104,104 +1104,109 @@ class FunctionMaker(object): ...@@ -1104,104 +1104,109 @@ class FunctionMaker(object):
graph_db = {} graph_db = {}
print 'loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db)) print 'loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db))
need_optimize = True if theano.config.cache_optimizations:
# the sole purpose of this loop is to set 'need_optimize' need_optimize = True
for i, graph_old in enumerate(graph_db.keys()): # the sole purpose of this loop is to set 'need_optimize'
inputs_old = graph_old.inputs for i, graph_old in enumerate(graph_db.keys()):
outputs_old = graph_old.outputs inputs_old = graph_old.inputs
size_old = len(graph_old.apply_nodes) outputs_old = graph_old.outputs
print 'looping through graph_db %d/%d' % (i + 1, len(graph_db)) size_old = len(graph_old.apply_nodes)
# Some heuristics to check is the same graphs have print 'looping through graph_db %d/%d' % (i + 1, len(graph_db))
# already been optimized before. # Some heuristics to check is the same graphs have
if len(inputs_new) != len(inputs_old): # already been optimized before.
# If the inputs are of different size, if len(inputs_new) != len(inputs_old):
# two graphs are for sure different # If the inputs are of different size,
print 'need to optimize, because input size is different' # two graphs are for sure different
continue print 'need to optimize, because input size is different'
elif len(outputs_new) != len(outputs_old): continue
# If the inputs are of different size, elif len(outputs_new) != len(outputs_old):
# two graphs are for sure different # If the inputs are of different size,
print 'need to optimize, because output size is different' # two graphs are for sure different
continue print 'need to optimize, because output size is different'
elif not all(input_new.type == input_old.type for continue
input_new, input_old in zip(inputs_new, inputs_old)): elif not all(input_new.type == input_old.type for
print 'need to optimize, because inputs are of different types' input_new, input_old in zip(inputs_new, inputs_old)):
continue print 'need to optimize, because inputs are of different types'
elif not all(output_new.type == output_old.type for continue
output_new, output_old in zip(outputs_new, outputs_old)): elif not all(output_new.type == output_old.type for
print 'need to optimize, because outputs are of different types' output_new, output_old in zip(outputs_new, outputs_old)):
continue print 'need to optimize, because outputs are of different types'
elif not size_old == size_new: continue
print 'need to optimize, because numbers of nodes in graph are different' elif not size_old == size_new:
continue print 'need to optimize, because numbers of nodes in graph are different'
continue
else:
flags = []
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'loop through outputs node for both graphs'
f2 = output_old.owner.fgraph.clone()
t1 = output_new
t2 = f2.outputs[i]
def removeAllFgraph(remove):
    """Recursively strip cached ``fgraph`` attributes from a graph variable.

    Deletes the ``fgraph`` attribute from *remove*, from its owner apply
    node (when it has one), recursively from the owner's inputs, and from
    each of the owner's outputs, so the graph can be compared or pickled
    independently of any FunctionGraph that once referenced it.

    Parameters
    ----------
    remove
        Graph variable to clean; it is mutated in place.

    Returns
    -------
    The same *remove* object, so calls can be used as expressions
    (e.g. in a list comprehension over ``owner.inputs``).
    """
    if hasattr(remove, 'fgraph'):
        del remove.fgraph
    # Use getattr with a default: "no owner attribute" and "owner is None"
    # both mean there is nothing further to clean.  Compare with `is`,
    # not `==` — identity is the correct test against None.
    owner = getattr(remove, 'owner', None)
    if owner is not None:
        if hasattr(owner, 'fgraph'):
            del owner.fgraph
        if hasattr(owner, 'inputs'):
            # Recurse into every ancestor variable feeding this node.
            owner.inputs = [removeAllFgraph(i) for i in owner.inputs]
        for o in owner.outputs:
            if hasattr(o, 'fgraph'):
                del o.fgraph
    return remove
t2 = removeAllFgraph(t2)
givens = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
temp = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
for key, value in temp.iteritems():
if key.type != value.type:
del givens[key]
flag = is_same_graph(t1, t2, givens=givens)
flags.append(flag)
is_same = all(flags)
if is_same:
# found the match
print 'found #TODO: he match, no need to optimize'
need_optimize = False
key = graph_old
break
# now optimize or not
if need_optimize:
# this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph'
before_opt = fgraph.clone()
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
graph_db.update({before_opt:fgraph})
f = open(graph_db_file, 'w+b')
cPickle.dump(graph_db, f, -1)
f.close()
print 'saved into graph_db'
else: else:
flags = [] print 'no opt, get graph from graph_db'
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))): # just read the optmized graph from graph_db
print 'loop through outputs node for both graphs' opt_time = 0
self.fgraph = graph_db[key]
f2 = output_old.owner.fgraph.clone() fgraph = self.fgraph
t1 = output_new # release stuff
t2 = f2.outputs[i] release_lock()
else:
def removeAllFgraph(remove):
    """Delete every cached ``fgraph`` attribute reachable from *remove*.

    The variable itself, its owner apply node (when present), the owner's
    inputs (recursively) and each of the owner's outputs are cleaned in
    place.  The mutated *remove* is returned so the call composes nicely.
    """
    if hasattr(remove, 'fgraph'):
        del remove.fgraph
    # Only descend when the variable was produced by an apply node.
    if hasattr(remove, 'owner') and remove.owner is not None:
        node = remove.owner
        if hasattr(node, 'fgraph'):
            del node.fgraph
        if hasattr(node, 'inputs'):
            cleaned = []
            for ancestor in node.inputs:
                cleaned.append(removeAllFgraph(ancestor))
            node.inputs = cleaned
        for out_var in node.outputs:
            if hasattr(out_var, 'fgraph'):
                del out_var.fgraph
    return remove
t2 = removeAllFgraph(t2)
givens = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
temp = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
for key, value in temp.iteritems():
if key.type != value.type:
del givens[key]
flag = is_same_graph(t1, t2, givens=givens)
flags.append(flag)
is_same = all(flags)
if is_same:
# found the match
print 'found #TODO: he match, no need to optimize'
need_optimize = False
key = graph_old
break
# now optimize or not
if need_optimize:
# this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph'
before_opt = fgraph.clone()
start_optimizer = time.time() start_optimizer = time.time()
optimizer_profile = optimizer(fgraph) optimizer_profile = optimizer(fgraph)
end_optimizer = time.time() end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer opt_time = end_optimizer - start_optimizer
graph_db.update({before_opt:fgraph})
f = open(graph_db_file, 'w+b')
cPickle.dump(graph_db, f, -1)
f.close()
print 'saved into graph_db'
else:
print 'no opt, get graph from graph_db'
# just read the optmized graph from graph_db
opt_time = 0
self.fgraph = graph_db[key]
fgraph = self.fgraph
# release stuff
release_lock()
print 'opt took %s' % opt_time print 'opt took %s' % opt_time
if profile: if profile:
profile.optimizer_time += opt_time profile.optimizer_time += opt_time
......
...@@ -538,3 +538,10 @@ AddConfigVar('check_input', ...@@ -538,3 +538,10 @@ AddConfigVar('check_input',
"(particularly for scalars) and reduce the number of generated C " "(particularly for scalars) and reduce the number of generated C "
"files.", "files.",
BoolParam(True)) BoolParam(True))
# Opt-in cache of already-optimized graphs: when enabled, FunctionMaker
# reuses a previously optimized graph instead of re-optimizing an
# identical one.  Off by default because it noticeably slows down the
# first optimization and may still contain bugs.
AddConfigVar('cache_optimizations',
             "Specify if the optimization cache should be used. This cache "
             "will store any optimized graph and its optimization. It "
             "actually slows down the first optimization a lot, and could "
             "possibly still contain some bugs. Use at your own risk.",
             BoolParam(False))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论