提交 4cbd9ff5 authored 作者: Frederic's avatar Frederic

Better interface for cleanup opt

上级 54482de9
...@@ -200,6 +200,12 @@ optdb.register('merge1', gof.MergeOptimizer(), ...@@ -200,6 +200,12 @@ optdb.register('merge1', gof.MergeOptimizer(),
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False), optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
# Register the merge optimizer in the canonicalize EquilibriumDB as a
# cleanup opt. Without this, since the equilibrium has
# ignore_newtrees=False, not all nodes would get merged if it were
# registered as a global optimizer with final_opt=True.
optdb['canonicalize'].register("merge", gof.opt.merge_optimizer, 'fast_run',
"fast_compile", cleanup=True)
optdb.register('merge1.2', gof.MergeOptimizer(), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge') 1.2, 'fast_run', 'fast_compile', 'merge')
......
...@@ -1774,8 +1774,6 @@ class NavigatorOptimizer(Optimizer): ...@@ -1774,8 +1774,6 @@ class NavigatorOptimizer(Optimizer):
raise raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
if replacements is True:
return True
old_vars = node.outputs old_vars = node.outputs
if isinstance(replacements, dict): if isinstance(replacements, dict):
old_vars = list(replacements.keys()) old_vars = list(replacements.keys())
...@@ -1998,7 +1996,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1998,7 +1996,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
failure_callback=None, failure_callback=None,
ignore_newtrees=True, ignore_newtrees=True,
max_use_ratio=None, max_use_ratio=None,
final_optimizers=None): final_optimizers=None,
cleanup_optimizers=None):
super(EquilibriumOptimizer, self).__init__( super(EquilibriumOptimizer, self).__init__(
None, None,
ignore_newtrees=ignore_newtrees, ignore_newtrees=ignore_newtrees,
...@@ -2007,6 +2006,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2007,6 +2006,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.local_optimizers_all = [] self.local_optimizers_all = []
self.global_optimizers = [] self.global_optimizers = []
self.final_optimizers = [] self.final_optimizers = []
self.cleanup_optimizers = []
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, LocalOptimizer):
...@@ -2019,6 +2019,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2019,6 +2019,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
if final_optimizers: if final_optimizers:
self.final_optimizers = final_optimizers self.final_optimizers = final_optimizers
if cleanup_optimizers:
self.cleanup_optimizers = cleanup_optimizers
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, ( assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number') 'max_use_ratio has to be a number')
...@@ -2042,6 +2044,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2042,6 +2044,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
for opt in self.final_optimizers: for opt in self.final_optimizers:
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
for opt in self.cleanup_optimizers:
opt.add_requirements(fgraph)
def apply(self, fgraph, start_from=None): def apply(self, fgraph, start_from=None):
change_tracker = ChangeTracker() change_tracker = ChangeTracker()
...@@ -2069,9 +2073,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2069,9 +2073,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created = {} node_created = {}
global_sub_profs = [] global_sub_profs = []
final_sub_profs = [] final_sub_profs = []
cleanup_sub_profs = []
for opt in (self.global_optimizers + for opt in (self.global_optimizers +
list(self.get_local_optimizers()) + list(self.get_local_optimizers()) +
self.final_optimizers): self.final_optimizers +
self.cleanup_optimizers):
global_process_count.setdefault(opt, 0) global_process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0) time_opts.setdefault(opt, 0)
node_created.setdefault(opt, 0) node_created.setdefault(opt, 0)
...@@ -2080,7 +2086,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2080,7 +2086,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count = {} process_count = {}
t0 = time.time() t0 = time.time()
changed = False changed = False
iter_cleanup_sub_profs = {}
# apply global optimizers # apply global optimizers
sub_profs = [] sub_profs = []
for gopt in self.global_optimizers: for gopt in self.global_optimizers:
...@@ -2104,6 +2110,17 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2104,6 +2110,17 @@ class EquilibriumOptimizer(NavigatorOptimizer):
global_opt_timing.append(float(time.time() - t0)) global_opt_timing.append(float(time.time() - t0))
# Apply cleanup, as the global optimizers may have made changes
# that require it.
for copt in self.cleanup_optimizers:
change_tracker.reset()
t_opt = time.time()
sub_prof = copt.apply(fgraph)
time_opts[copt] += time.time() - t_opt
iter_cleanup_sub_profs[copt] = [sub_prof]
if change_tracker.changed:
changed = True
# apply local optimizer # apply local optimizer
topo_t0 = time.time() topo_t0 = time.time()
q = deque(graph.io_toposort(fgraph.inputs, start_from)) q = deque(graph.io_toposort(fgraph.inputs, start_from))
...@@ -2137,12 +2154,18 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2137,12 +2154,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
t_opt = time.time() t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt) lopt_change = self.process_node(fgraph, node, lopt)
time_opts[lopt] += time.time() - t_opt time_opts[lopt] += time.time() - t_opt
# TODO: if not ...: continue
if lopt_change: if lopt_change:
process_count.setdefault(lopt, 0) process_count.setdefault(lopt, 0)
process_count[lopt] += 1 process_count[lopt] += 1
global_process_count[lopt] += 1 global_process_count[lopt] += 1
changed = True changed = True
node_created[lopt] += change_tracker.nb_imported - nb node_created[lopt] += change_tracker.nb_imported - nb
for copt in self.cleanup_optimizers:
t_opt = time.time()
sub_prof = copt.apply(fgraph)
time_opts[copt] += time.time() - t_opt
iter_cleanup_sub_profs[copt].append(sub_prof)
if global_process_count[lopt] > max_use: if global_process_count[lopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(lopt, "name", None) or opt_name = (getattr(lopt, "name", None) or
...@@ -2176,7 +2199,23 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2176,7 +2199,23 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs.append(sub_profs) final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt global_opt_timing[-1] += time.time() - t_before_final_opt
# Apply cleanup, as the final optimizers may have made changes
# that require it.
for copt in self.cleanup_optimizers:
t_opt = time.time()
sub_prof = copt.apply(fgraph)
time_opts[copt] += time.time() - t_opt
iter_cleanup_sub_profs[copt] = [sub_prof]
# Merge the cleanup profiles collected during this iteration.
c_sub_profs = []
for copt, sub_profs in iteritems(iter_cleanup_sub_profs):
sub_prof = sub_profs[0]
for s_p in sub_profs[1:]:
sub_prof = copt.merge_profile(sub_prof, s_p)
c_sub_profs.append(sub_prof)
cleanup_sub_profs.append(c_sub_profs)
loop_process_count.append(process_count) loop_process_count.append(process_count)
loop_timing.append(float(time.time() - t0)) loop_timing.append(float(time.time() - t0))
...@@ -2191,7 +2230,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2191,7 +2230,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
return (self, loop_timing, loop_process_count, return (self, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes), (start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing, global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
node_created, global_sub_profs, final_sub_profs) node_created, global_sub_profs, final_sub_profs, cleanup_sub_profs)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
...@@ -2207,7 +2246,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2207,7 +2246,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
(opt, loop_timing, loop_process_count, (opt, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes), (start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing, global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
node_created, global_sub_profs, final_sub_profs) = prof node_created, global_sub_profs, final_sub_profs,
cleanup_sub_profs) = prof
blanc = (' ' * level) blanc = (' ' * level)
print(blanc, "EquilibriumOptimizer", end=' ', file=stream) print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
...@@ -2225,6 +2265,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2225,6 +2265,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print(blanc, " time in global optimizers %.3fs" % s, file=stream) print(blanc, " time in global optimizers %.3fs" % s, file=stream)
s = sum([time_opts[o] for o in opt.final_optimizers]) s = sum([time_opts[o] for o in opt.final_optimizers])
print(blanc, " time in final optimizers %.3fs" % s, file=stream) print(blanc, " time in final optimizers %.3fs" % s, file=stream)
s = sum([time_opts[o] for o in opt.cleanup_optimizers])
print(blanc, " time in cleanup optimizers %.3fs" % s, file=stream)
for i in range(len(loop_timing)): for i in range(len(loop_timing)):
lopt = "" lopt = ""
if loop_process_count[i]: if loop_process_count[i]:
...@@ -2248,7 +2290,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2248,7 +2290,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count = {} process_count = {}
for o in (opt.global_optimizers + for o in (opt.global_optimizers +
list(opt.get_local_optimizers()) + list(opt.get_local_optimizers()) +
list(opt.final_optimizers)): list(opt.final_optimizers) +
list(opt.cleanup_optimizers)):
process_count.setdefault(o, 0) process_count.setdefault(o, 0)
for count in loop_process_count: for count in loop_process_count:
for o, v in iteritems(count): for o, v in iteritems(count):
...@@ -2278,12 +2321,13 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2278,12 +2321,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream) print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream)
print(file=stream) print(file=stream)
gf_opts = [o for o in (opt.global_optimizers + gf_opts = [o for o in (opt.global_optimizers +
list(opt.final_optimizers)) list(opt.final_optimizers) +
list(opt.cleanup_optimizers))
if o.print_profile.func_code is not if o.print_profile.func_code is not
Optimizer.print_profile.func_code] Optimizer.print_profile.func_code]
if not gf_opts: if not gf_opts:
return return
print(blanc, "Global and final optimizer", file=stream) print(blanc, "Global, final and clean up optimizers", file=stream)
for i in range(len(loop_timing)): for i in range(len(loop_timing)):
print(blanc, "Iter %d" % i, file=stream) print(blanc, "Iter %d" % i, file=stream)
for o, prof in zip(opt.global_optimizers, global_sub_profs[i]): for o, prof in zip(opt.global_optimizers, global_sub_profs[i]):
...@@ -2296,6 +2340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2296,6 +2340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
o.print_profile(stream, prof, level + 2) o.print_profile(stream, prof, level + 2)
except NotImplementedError: except NotImplementedError:
print(blanc, "merge not implemented for ", o) print(blanc, "merge not implemented for ", o)
for o, prof in zip(opt.cleanup_optimizers, cleanup_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
print(blanc, "merge not implemented for ", o)
@staticmethod @staticmethod
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
...@@ -2310,10 +2359,16 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2310,10 +2359,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
prof2[0].final_optimizers) prof2[0].final_optimizers)
else: else:
final_optimizers = None final_optimizers = None
if len(prof1[0].cleanup_optimizers) > 0 or len(prof2[0].cleanup_optimizers) > 0:
cleanup_optimizers = OrderedSet(prof1[0].cleanup_optimizers).union(
prof2[0].cleanup_optimizers)
else:
cleanup_optimizers = None
new_opt = EquilibriumOptimizer( new_opt = EquilibriumOptimizer(
local_optimizers.union(global_optimizers), local_optimizers.union(global_optimizers),
max_use_ratio=1, max_use_ratio=1,
final_optimizers=final_optimizers) final_optimizers=final_optimizers,
cleanup_optimizers=cleanup_optimizers)
def merge_list(l1, l2): def merge_list(l1, l2):
l = copy.copy(l1) l = copy.copy(l1)
...@@ -2361,6 +2416,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2361,6 +2416,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created = merge_dict(prof1[8], prof2[8]) node_created = merge_dict(prof1[8], prof2[8])
global_sub_profs = merge_list(prof1[9], prof2[9]) global_sub_profs = merge_list(prof1[9], prof2[9])
final_sub_profs = merge_list(prof1[10], prof2[10]) final_sub_profs = merge_list(prof1[10], prof2[10])
cleanup_sub_profs = merge_list(prof1[11], prof2[11])
return (new_opt, return (new_opt,
loop_timing, loop_timing,
loop_process_count, loop_process_count,
...@@ -2371,7 +2427,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2371,7 +2427,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
io_toposort_timing, io_toposort_timing,
node_created, node_created,
global_sub_profs, global_sub_profs,
final_sub_profs) final_sub_profs,
cleanup_sub_profs)
################# #################
# Utilities # # Utilities #
......
...@@ -265,28 +265,35 @@ class EquilibriumDB(DB): ...@@ -265,28 +265,35 @@ class EquilibriumDB(DB):
super(EquilibriumDB, self).__init__() super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.__final__ = {} self.__final__ = {}
self.__cleanup__ = {}
def register(self, name, obj, *tags, **kwtags): def register(self, name, obj, *tags, **kwtags):
if 'final_opt' in kwtags: final_opt = kwtags.pop('final_opt', False)
final_opt = kwtags['final_opt'] cleanup = kwtags.pop('cleanup', False)
kwtags.pop('final_opt', None) # An opt should not be final and clean up
else: assert not (final_opt and cleanup)
final_opt = False
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags) super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
_opts = super(EquilibriumDB, self).query(*tags, **kwtags) _opts = super(EquilibriumDB, self).query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)] final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts] cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name,
False)]
opts = [o for o in _opts
if o not in final_opts and o not in cleanup_opts]
if len(final_opts) == 0: if len(final_opts) == 0:
final_opts = None final_opts = None
if len(cleanup_opts) == 0:
cleanup_opts = None
return opt.EquilibriumOptimizer( return opt.EquilibriumOptimizer(
opts, opts,
max_use_ratio=config.optdb.max_use_ratio, max_use_ratio=config.optdb.max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
failure_callback=opt.NavigatorOptimizer.warn_inplace, failure_callback=opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts) final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts)
class SequenceDB(DB): class SequenceDB(DB):
......
...@@ -47,7 +47,6 @@ from theano.tensor.type import (values_eq_approx_remove_inf, ...@@ -47,7 +47,6 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge, from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO from six import StringIO
...@@ -504,29 +503,6 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -504,29 +503,6 @@ def register_specialize_device(lopt, *tags, **kwargs):
return lopt return lopt
# Register in the canonizer Equilibrium as a local opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
#
# This work due to those properties:
# 1) the EQ will execute first the optimizer that trac all nodes.
# 2) after an local optimization being applied, if the
# current node is still in the graph, it will continue to the next
# local optimizer. So this won't trigger more iteration.
def add_merge_feature(fgraph):
if not hasattr(fgraph, 'merge_feature'):
fgraph.attach_feature(theano.gof.opt.MergeFeature())
@register_canonicalize('fast_compile', 'merge')
@gof.local_optimizer(None, requirements=[add_merge_feature])
def local_merge_optimizer(node):
if node.fgraph.merge_feature.scheduled:
ret = merge_optimizer(node.fgraph)
return ret[5] > 0
##################### #####################
# Dot optimizations # # Dot optimizations #
##################### #####################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论