提交 4cbd9ff5 authored 作者: Frederic's avatar Frederic

Better interface for cleanup opt

上级 54482de9
...@@ -200,6 +200,12 @@ optdb.register('merge1', gof.MergeOptimizer(), ...@@ -200,6 +200,12 @@ optdb.register('merge1', gof.MergeOptimizer(),
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False), optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
# Register in the canonizer Equilibrium as a local opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
optdb['canonicalize'].register("merge", gof.opt.merge_optimizer, 'fast_run',
"fast_compile", cleanup=True)
optdb.register('merge1.2', gof.MergeOptimizer(), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge') 1.2, 'fast_run', 'fast_compile', 'merge')
......
差异被折叠。
...@@ -265,28 +265,35 @@ class EquilibriumDB(DB): ...@@ -265,28 +265,35 @@ class EquilibriumDB(DB):
super(EquilibriumDB, self).__init__() super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.__final__ = {} self.__final__ = {}
self.__cleanup__ = {}
def register(self, name, obj, *tags, **kwtags): def register(self, name, obj, *tags, **kwtags):
if 'final_opt' in kwtags: final_opt = kwtags.pop('final_opt', False)
final_opt = kwtags['final_opt'] cleanup = kwtags.pop('cleanup', False)
kwtags.pop('final_opt', None) # An opt should not be final and clean up
else: assert not (final_opt and cleanup)
final_opt = False
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags) super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
_opts = super(EquilibriumDB, self).query(*tags, **kwtags) _opts = super(EquilibriumDB, self).query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)] final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts] cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name,
False)]
opts = [o for o in _opts
if o not in final_opts and o not in cleanup_opts]
if len(final_opts) == 0: if len(final_opts) == 0:
final_opts = None final_opts = None
if len(cleanup_opts) == 0:
cleanup_opts = None
return opt.EquilibriumOptimizer( return opt.EquilibriumOptimizer(
opts, opts,
max_use_ratio=config.optdb.max_use_ratio, max_use_ratio=config.optdb.max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
failure_callback=opt.NavigatorOptimizer.warn_inplace, failure_callback=opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts) final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts)
class SequenceDB(DB): class SequenceDB(DB):
......
...@@ -47,7 +47,6 @@ from theano.tensor.type import (values_eq_approx_remove_inf, ...@@ -47,7 +47,6 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge, from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof.opt import merge_optimizer
from theano.gof import toolbox from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO from six import StringIO
...@@ -504,29 +503,6 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -504,29 +503,6 @@ def register_specialize_device(lopt, *tags, **kwargs):
return lopt return lopt
# Register in the canonizer Equilibrium as a local opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
#
# This work due to those properties:
# 1) the EQ will execute first the optimizer that trac all nodes.
# 2) after an local optimization being applied, if the
# current node is still in the graph, it will continue to the next
# local optimizer. So this won't trigger more iteration.
def add_merge_feature(fgraph):
if not hasattr(fgraph, 'merge_feature'):
fgraph.attach_feature(theano.gof.opt.MergeFeature())
@register_canonicalize('fast_compile', 'merge')
@gof.local_optimizer(None, requirements=[add_merge_feature])
def local_merge_optimizer(node):
if node.fgraph.merge_feature.scheduled:
ret = merge_optimizer(node.fgraph)
return ret[5] > 0
##################### #####################
# Dot optimizations # # Dot optimizations #
##################### #####################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论