提交 ed31d840 authored 作者: Ziye Fan's avatar Ziye Fan

add argument final_optimizer for EquilibriumOptimizer

上级 2425cd11
......@@ -1690,7 +1690,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
optimizers,
failure_callback=None,
ignore_newtrees=True,
max_use_ratio=None):
max_use_ratio=None,
final_optimizers=None):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
......@@ -1710,6 +1711,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.local_optimizers_map = dict()
self.local_optimizers_all = []
self.global_optimizers = []
self.final_optimizers = []
for opt in optimizers:
if isinstance(opt, LocalOptimizer):
......@@ -1720,6 +1722,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.local_optimizers_map.setdefault(c, []).append(opt)
else:
self.global_optimizers.append(opt)
if final_optimizers:
self.final_optimizers = final_optimizers
self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number')
......@@ -1766,7 +1770,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
io_toposort_timing = []
nb_nodes = []
node_created = {}
for opt in self.global_optimizers + list(self.get_local_optimizers()):
for opt in (self.global_optimizers +
list(self.get_local_optimizers()) +
self.final_optimizers):
global_process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0)
node_created.setdefault(opt, 0)
......@@ -1845,6 +1851,26 @@ class EquilibriumOptimizer(NavigatorOptimizer):
finally:
self.detach_updater(fgraph, u)
# Apply final optimizers
for gopt in self.final_optimizers:
change_tracker.reset()
nb = change_tracker.nb_imported
t_opt = time.time()
gopt.apply(fgraph)
time_opts[gopt] += time.time() - t_opt
if change_tracker.changed:
process_count.setdefault(gopt, 0)
process_count[gopt] += 1
global_process_count[gopt] += 1
changed = True
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
global_opt_timing.append(float(time.time() - t0))
loop_process_count.append(process_count)
loop_timing.append(float(time.time() - t0))
......
......@@ -225,14 +225,29 @@ class EquilibriumDB(DB):
def __init__(self, ignore_newtrees=True):
super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees
self.__final__ = {}
def register(self, name, obj, *tags, **kwtags):
# if name == 'cut_gpua_constant_transfers':
# import ipdb;ipdb.set_trace()
if 'final_opt' in tags:
final_opt = True
else:
final_opt = False
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt
def query(self, *tags, **kwtags):
opts = super(EquilibriumDB, self).query(*tags, **kwtags)
final_opts = [o for o in opts if self.__final__.get(o.name, False)]
if len(final_opts) == 0:
final_opts = None
return opt.EquilibriumOptimizer(
opts,
max_use_ratio=config.optdb.max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
failure_callback=opt.NavigatorOptimizer.warn_inplace)
failure_callback=opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts)
class SequenceDB(DB):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论