提交 4cbd9ff5 authored 作者: Frederic's avatar Frederic

Better interface for cleanup opt

上级 54482de9
......@@ -200,6 +200,12 @@ optdb.register('merge1', gof.MergeOptimizer(),
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False),
1, 'fast_run', 'fast_compile')
# Register in the canonizer Equilibrium as a local opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
optdb['canonicalize'].register("merge", gof.opt.merge_optimizer, 'fast_run',
"fast_compile", cleanup=True)
optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge')
......
差异被折叠。
......@@ -265,28 +265,35 @@ class EquilibriumDB(DB):
super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees
self.__final__ = {}
self.__cleanup__ = {}
def register(self, name, obj, *tags, **kwtags):
if 'final_opt' in kwtags:
final_opt = kwtags['final_opt']
kwtags.pop('final_opt', None)
else:
final_opt = False
final_opt = kwtags.pop('final_opt', False)
cleanup = kwtags.pop('cleanup', False)
# An opt should not be final and clean up
assert not (final_opt and cleanup)
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags):
_opts = super(EquilibriumDB, self).query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts]
cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name,
False)]
opts = [o for o in _opts
if o not in final_opts and o not in cleanup_opts]
if len(final_opts) == 0:
final_opts = None
if len(cleanup_opts) == 0:
cleanup_opts = None
return opt.EquilibriumOptimizer(
opts,
max_use_ratio=config.optdb.max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
failure_callback=opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts)
final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts)
class SequenceDB(DB):
......
......@@ -47,7 +47,6 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO
......@@ -504,29 +503,6 @@ def register_specialize_device(lopt, *tags, **kwargs):
return lopt
# Register in the canonizer Equilibrium as a local opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
#
# This work due to those properties:
# 1) the EQ will execute first the optimizer that trac all nodes.
# 2) after an local optimization being applied, if the
# current node is still in the graph, it will continue to the next
# local optimizer. So this won't trigger more iteration.
def add_merge_feature(fgraph):
if not hasattr(fgraph, 'merge_feature'):
fgraph.attach_feature(theano.gof.opt.MergeFeature())
@register_canonicalize('fast_compile', 'merge')
@gof.local_optimizer(None, requirements=[add_merge_feature])
def local_merge_optimizer(node):
if node.fgraph.merge_feature.scheduled:
ret = merge_optimizer(node.fgraph)
return ret[5] > 0
#####################
# Dot optimizations #
#####################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论