提交 e3079e41 authored 作者: Hengjean's avatar Hengjean 提交者: Frederic

Began refactoring. Fixed bug.

上级 da877d34
...@@ -1067,163 +1067,136 @@ class FunctionMaker(object): ...@@ -1067,163 +1067,136 @@ class FunctionMaker(object):
theano.config.compute_test_value = theano.config.compute_test_value_opt theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False gof.Op.add_stack_trace_on_call = False
def optimize_graph(fgraph): from theano.gof.compilelock import get_lock, release_lock
''' import os.path
params graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl')
------ # the inputs, outputs, and size of the graph to be optimized
fgraph: the new graph to be optimized, optimized in-place. inputs_new = [inp.variable for inp in inputs]
{before_opt: after_opt, ....} outputs_new = [out.variable for out in outputs]
size_new = len(fgraph.apply_nodes)
return need_optimize = False
------ get_lock()
opt_time: timing key = None
''' '''
from theano.gof.compilelock import get_lock, release_lock graph_db and need_optimize
import cPickle '''
import os.path if os.path.isfile(graph_db_file):
graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl') print 'graph_db exists'
# the inputs, outputs, and size of the graph to be optimized else:
inputs_new = fgraph.inputs # create graph_db
outputs_new = fgraph.outputs f = open(graph_db_file, 'w+b')
size_new = len(fgraph.apply_nodes) print 'created new graph_db %s' % graph_db_file
need_optimize = False f.close
get_lock()
key = None # load the graph_db dictionary
''' try:
graph_db and need_optimize f = open(graph_db_file, 'r+b')
''' graph_db = cPickle.load(f)
if os.path.isfile(graph_db_file): f.close()
print 'graph_db exists' print 'graph_db is not empty'
except EOFError, e:
# the file has nothing in it
print e
print 'graph_db is empty'
graph_db = {}
print 'loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db))
need_optimize = True
# the sole purpose of this loop is to set 'need_optimize'
for i, graph_old in enumerate(graph_db.keys()):
inputs_old = graph_old.inputs
outputs_old = graph_old.outputs
size_old = len(graph_old.apply_nodes)
print 'looping through graph_db %d/%d' % (i + 1, len(graph_db))
# Some heuristics to check whether the same graph has
# already been optimized before.
if len(inputs_new) != len(inputs_old):
# If the inputs are of different size,
# two graphs are for sure different
print 'need to optimize, because input size is different'
continue
elif len(outputs_new) != len(outputs_old):
# If the outputs are of different size,
# two graphs are for sure different
print 'need to optimize, because output size is different'
continue
elif not all(input_new.type == input_old.type for
input_new, input_old in zip(inputs_new, inputs_old)):
print 'need to optimize, because inputs are of different types'
continue
elif not all(output_new.type == output_old.type for
output_new, output_old in zip(outputs_new, outputs_old)):
print 'need to optimize, because outputs are of different types'
continue
elif not size_old == size_new:
print 'need to optimize, because numbers of nodes in graph are different'
continue
else: else:
# create graph_db flags = []
f = open(graph_db_file, 'w+b') for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'created new graph_db %s' % graph_db_file print 'loop through outputs node for both graphs'
f.close
f2 = output_old.owner.fgraph.clone()
# load the graph_db dictionary t1 = output_new
try: t2 = f2.outputs[i]
f = open(graph_db_file, 'r+b')
graph_db = cPickle.load(f) def removeAllFgraph(remove):
f.close() if hasattr(remove, 'fgraph'):
print 'graph_db is not empty' del remove.fgraph
except EOFError, e: if hasattr(remove, 'owner'):
# the file has nothing in it if remove.owner == None:
print e pass
print 'graph_db is empty' else:
graph_db = {} if hasattr(remove.owner, 'fgraph'):
del remove.owner.fgraph
print 'loaded graph_db from %s, size=%d'%(graph_db_file,len(graph_db)) if hasattr(remove.owner, 'inputs'):
need_optimize = True remove.owner.inputs = [removeAllFgraph(
# the sole purpose of this loop is to set 'need_optimize' i) for i in remove.owner.inputs]
for i, graph_old in enumerate(graph_db.keys()): for o in remove.owner.outputs:
inputs_old = graph_old.inputs if hasattr(o, 'fgraph'):
outputs_old = graph_old.outputs del o.fgraph
size_old = len(graph_old.apply_nodes) return remove
print 'looping through graph_db %d/%d'%(i+1,len(graph_db))
# Some heuristics to check is the same graphs have t2 = removeAllFgraph(t2)
# already been optimized before. givens = dict(zip(gof.graph.inputs([t1]),
if len(inputs_new) != len(inputs_old): gof.graph.inputs([t2])))
# If the inputs are of different size, temp = dict(zip(gof.graph.inputs([t1]),
# two graphs are for sure different gof.graph.inputs([t2])))
print 'need to optimize, because input size is different' for key, value in temp.iteritems():
continue if key.type != value.type:
elif len(outputs_new) != len(outputs_old): del givens[key]
# If the inputs are of different size, flag = is_same_graph(t1, t2, givens=givens)
# two graphs are for sure different flags.append(flag)
print 'need to optimize, because output size is different'
continue is_same = all(flags)
elif not all(input_new.type == input_old.type for if is_same:
input_new, input_old in zip(inputs_new, inputs_old)): # found the match
print 'need to optimize, because inputs are of different types' print 'found #TODO: he match, no need to optimize'
continue need_optimize = False
elif not all(output_new.type == output_old.type for key = graph_old
output_new, output_old in zip(outputs_new, outputs_old)): break
print 'need to optimize, because outputs are of different types'
continue
elif not len(fgraph.apply_nodes) == len(graph_old.apply_nodes):
print 'need to optimize, because numbers of nodes in graph are different'
continue
else:
# when the both inputs are of the same size
givens = dict(zip(inputs_new, inputs_old))
'''
# strip .fgraph off the givens
i_new = [copy.deepcopy(input_new) for input_new in inputs_new]
i_old = [copy.deepcopy(input_old) for input_old in inputs_old]
for node in i_new:
node.fgraph = None
for node in i_old:
node.fgraph = None
givens = dict(zip(i_new, i_old))
'''
# each element indicates if one of the outputs has the same graph
flags = []
for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
print 'loop through outputs node for both graphs'
f1 = output_new.owner.fgraph.clone()
f2 = output_old.owner.fgraph.clone()
# is_same_graph complains if fgraph is not None
t1 = f1.outputs[i]
t2 = f2.outputs[i]
def removeAllFgraph(remove):
if hasattr(remove, 'fgraph'):
del remove.fgraph
if hasattr(remove, 'owner'):
if remove.owner == None:
pass
else:
if hasattr(remove.owner, 'fgraph'):
del remove.owner.fgraph
if hasattr(remove.owner, 'inputs'):
remove.owner.inputs = [removeAllFgraph(
i) for i in remove.owner.inputs]
for o in remove.owner.outputs:
if hasattr(o, 'fgraph'):
del o.fgraph
return remove
t1 = removeAllFgraph(t1)
t2 = removeAllFgraph(t2)
givens = dict(zip(gof.graph.ancestors([t1]),
gof.graph.ancestors([t2])))
flag = is_same_graph(t1, t2, givens=givens)
flags.append(flag)
is_same = all(flags)
if is_same:
# found the match
print 'found #TODO: he match, no need to optimize'
need_optimize = False
key = graph_old
break
# now optimize or not # now optimize or not
if need_optimize: if need_optimize:
# this is a brand new graph, optimize it, save it to graph_db # this is a brand new graph, optimize it, save it to graph_db
print 'optimizing the graph' print 'optimizing the graph'
before_opt = fgraph.clone() before_opt = fgraph.clone()
start_optimizer = time.time() start_optimizer = time.time()
optimizer_profile = optimizer(fgraph) optimizer_profile = optimizer(fgraph)
end_optimizer = time.time() end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer opt_time = end_optimizer - start_optimizer
graph_db.update({before_opt:fgraph}) graph_db.update({before_opt:fgraph})
f = open(graph_db_file, 'w+b') f = open(graph_db_file, 'w+b')
cPickle.dump(graph_db, f, -1) cPickle.dump(graph_db, f, -1)
f.close() f.close()
print 'saved into graph_db' print 'saved into graph_db'
else: else:
print 'no opt, get graph from graph_db' print 'no opt, get graph from graph_db'
# just read the optimized graph from graph_db # just read the optimized graph from graph_db
opt_time = 0 opt_time = 0
fgraph = graph_db[key] fgraph = graph_db[key]
# release stuff # release stuff
release_lock() release_lock()
return opt_time
opt_time = optimize_graph(fgraph)
print 'opt took %s'%opt_time print 'opt took %s'%opt_time
if profile: if profile:
......
...@@ -694,7 +694,7 @@ class VM_Linker(link.LocalLinker): ...@@ -694,7 +694,7 @@ class VM_Linker(link.LocalLinker):
if k.owner and k.clients: if k.owner and k.clients:
ls = [] ls = []
for cl in k.clients: for cl in k.clients:
if cl[0] is not 'output': if cl[0] != 'output':
ls += cl[0].outputs ls += cl[0].outputs
dependencies[k] += ls dependencies[k] += ls
return dependencies return dependencies
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论