提交 5999b913 authored 作者: Frederic's avatar Frederic

Use Ordered dict and set to be sure it is deterministic

上级 c802b3ed
...@@ -17,11 +17,12 @@ import numpy ...@@ -17,11 +17,12 @@ import numpy
import theano import theano
from theano import config from theano import config
from theano.compat import izip from theano.compat import izip, OrderedDict
from six import string_types, iteritems, itervalues from six import string_types, iteritems, itervalues
from six.moves import reduce from six.moves import reduce
from theano.gof import graph, op, utils, unify, toolbox from theano.gof import graph, op, utils, unify, toolbox
from theano.gof.fg import InconsistencyError from theano.gof.fg import InconsistencyError
from theano.misc.ordered_set import OrderedSet
from . import destroyhandler as dh from . import destroyhandler as dh
...@@ -1715,7 +1716,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1715,7 +1716,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
None, None,
ignore_newtrees=ignore_newtrees, ignore_newtrees=ignore_newtrees,
failure_callback=failure_callback) failure_callback=failure_callback)
self.local_optimizers_map = dict() self.local_optimizers_map = OrderedDict()
self.local_optimizers_all = [] self.local_optimizers_all = []
self.global_optimizers = [] self.global_optimizers = []
self.final_optimizers = [] self.final_optimizers = []
...@@ -1955,41 +1956,58 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1955,41 +1956,58 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for count in loop_process_count: for count in loop_process_count:
for o, v in iteritems(count): for o, v in iteritems(count):
process_count[o] += v process_count[o] += v
for opt, count in iteritems(process_count): for o, count in iteritems(process_count):
if count > 0: if count > 0:
count_opt.append((time_opts[opt], count, count_opt.append((time_opts[o], count,
node_created[opt], opt)) node_created[o], o))
else: else:
not_used.append((time_opts[opt], opt)) not_used.append((time_opts[o], o))
not_used_time += time_opts[opt] not_used_time += time_opts[o]
if count_opt: if count_opt:
print(blanc, \ print(blanc, \
' times - times applied - nb node created - name:', file=stream) ' times - times applied - nb node created - name:', file=stream)
count_opt.sort() count_opt.sort()
for (t, count, n_created, opt) in count_opt[::-1]: for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % ( print(blanc, ' %.3fs - %d - %d - %s' % (
t, count, n_created, opt), file=stream) t, count, n_created, o), file=stream)
print(blanc, ' %.3fs - in %d optimization that where not used (display only those with a runtime > 0)' % ( print(blanc, ' %.3fs - in %d optimization that where not used (display only those with a runtime > 0)' % (
not_used_time, len(not_used)), file=stream) not_used_time, len(not_used)), file=stream)
not_used.sort() not_used.sort()
for (t, opt) in not_used[::-1]: for (t, o) in not_used[::-1]:
if t > 0: if t > 0:
# Skip opt that have 0 times, they probably wasn't even tried. # Skip opt that have 0 times, they probably wasn't even tried.
print(blanc + " ", ' %.3fs - %s' % (t, opt), file=stream) print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream)
print(file=stream) print(file=stream)
if (len(opt.global_optimizers) + len(opt.final_optimizers) == 0 or
# sum([time_opts[o] for o in opt.global_optimizers + opt.final_optimizers]) < 1 or
False):
return
for i in range(len(loop_timing)):
print(blanc, "Iter %d" % i, file=stream)
for o, prof in zip(opt.global_optimizers, global_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
import pdb;pdb.set_trace()
print(blanc, "merge not implemented for ", o)
for o, prof in zip(opt.final_optimizers, final_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
import pdb;pdb.set_trace()
print(blanc, "merge not implemented for ", o)
@staticmethod @staticmethod
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes, #(opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
local_optimizers = set(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers()) prof2[0].get_local_optimizers())
global_optimizers = set(prof1[0].global_optimizers).union( global_optimizers = OrderedSet(prof1[0].global_optimizers).union(
prof2[0].global_optimizers) prof2[0].global_optimizers)
if len(prof1[0].final_optimizers) > 0 or len(prof2[0].final_optimizers) > 0: if len(prof1[0].final_optimizers) > 0 or len(prof2[0].final_optimizers) > 0:
final_optimizers = set(prof1[0].final_optimizers).union( final_optimizers = OrderedSet(prof1[0].final_optimizers).union(
prof2[0].final_optimizers) prof2[0].final_optimizers)
else: else:
final_optimizers = None final_optimizers = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论