Commit 68bc36ad, authored by Frédéric Bastien, committed by GitHub

Merge pull request #4983 from nouiz/opt_order

Fix optimization order and jenkins
@@ -322,7 +322,10 @@ class SequenceDB(DB):
     def register(self, name, obj, position, *tags):
         super(SequenceDB, self).register(name, obj, *tags)
         if position == 'last':
-            self.__position__[name] = max(self.__position__.values())
+            if len(self.__position__) == 0:
+                self.__position__[name] = 0
+            else:
+                self.__position__[name] = max(self.__position__.values()) + 1
         else:
             assert isinstance(position, (integer_types, float))
             self.__position__[name] = position
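Note: the behavioural difference in the 'last' branch can be seen with a minimal standalone sketch of just the position bookkeeping (the helper names below are illustrative, not Theano's API). Before the patch a second 'last' registration reused the current maximum and tied with the previous entry, and an empty DB would raise ValueError; after the patch each 'last' registration gets a strictly larger position.

# Minimal sketch of the 'last' bookkeeping only; not Theano's SequenceDB.
def old_last(positions):
    # Pre-patch: reuses the current maximum, so two 'last' registrations tie,
    # and an empty dict raises ValueError.
    return max(positions.values())

def new_last(positions):
    # Post-patch: an empty DB starts at 0, otherwise go one past the maximum,
    # so later 'last' registrations really sort after earlier ones.
    if len(positions) == 0:
        return 0
    return max(positions.values()) + 1

positions = {}
for name in ('a', 'b', 'c'):
    positions[name] = new_last(positions)
print(positions)  # {'a': 0, 'b': 1, 'c': 2}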
@@ -407,10 +410,25 @@ class LocalGroupDB(DB):
         self.failure_callback = None
         self.apply_all_opts = apply_all_opts
         self.profile = profile
+        self.__position__ = {}
+
+    def register(self, name, obj, *tags, **kwargs):
+        super(LocalGroupDB, self).register(name, obj, *tags)
+        position = kwargs.pop('position', 'last')
+        if position == 'last':
+            if len(self.__position__) == 0:
+                self.__position__[name] = 0
+            else:
+                self.__position__[name] = max(self.__position__.values()) + 1
+        else:
+            assert isinstance(position, (integer_types, float))
+            self.__position__[name] = position
 
     def query(self, *tags, **kwtags):
         # For the new `useless` optimizer
-        opts = super(LocalGroupDB, self).query(*tags, **kwtags)
+        opts = list(super(LocalGroupDB, self).query(*tags, **kwtags))
+        opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
         ret = opt.LocalOptGroup(*opts,
                                 apply_all_opts=self.apply_all_opts,
                                 profile=self.profile)
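Note: the effect of the new sort key in query() can be sketched with stand-in objects (the real entries are Theano local optimizers; the positions and names below are only illustrative). The queried opts are materialised into a list and ordered by registration position, with the name as a deterministic tie-breaker.

from collections import namedtuple

Opt = namedtuple('Opt', 'name')  # stand-in for a local optimizer

# Positions as the new LocalGroupDB.register would record them.
positions = {'conv_meta': 0, 'conv_fft_valid': 10,
             'local_conv_dnn': 20, 'local_conv_gemm': 30}

# A set-backed query may return the opts in arbitrary order.
queried = [Opt('local_conv_gemm'), Opt('conv_meta'),
           Opt('local_conv_dnn'), Opt('conv_fft_valid')]

# Same key as the patched query(): position first, then name as tie-breaker.
queried.sort(key=lambda o: (positions[o.name], o.name))
print([o.name for o in queried])
# ['conv_meta', 'conv_fft_valid', 'local_conv_dnn', 'local_conv_gemm']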
@@ -1656,23 +1656,23 @@ register_opt()(conv_groupopt)
 # FFT gets the highest priority (lowest number), but is disabled by default.
 # It can be enabled by including 'conv_fft'.
-conv_groupopt.register('conv_fft_valid', local_conv_fft_valid, 10,
-                       'conv_fft')
-conv_groupopt.register('conv_fft_full', local_conv_fft_full, 10,
-                       'conv_fft')
+conv_groupopt.register('conv_fft_valid', local_conv_fft_valid,
+                       'conv_fft', position=10)
+conv_groupopt.register('conv_fft_full', local_conv_fft_full,
+                       'conv_fft', position=10)
 # cuDNN is the second, but only registered if cuDNN is available.
 # It can be disabled by excluding 'conv_dnn' or 'cudnn'.
 # We can't check at import if dnn is available, so we must always
 # register it. This do not cause problem as if it is not avail, the
 # opt will do nothing.
-conv_groupopt.register('local_conv_dnn', dnn.local_conv_dnn, 20,
+conv_groupopt.register('local_conv_dnn', dnn.local_conv_dnn,
                        'conv_dnn',
-                       'fast_compile', 'fast_run', 'cudnn')
+                       'fast_compile', 'fast_run', 'cudnn', position=20)
 # The GEMM-based convolution comes last to catch all remaining cases.
 # It can be disabled by excluding 'conv_gemm'.
-conv_groupopt.register('local_conv_gemm', local_conv_gemm, 30,
+conv_groupopt.register('local_conv_gemm', local_conv_gemm,
                        'conv_gemm',
-                       'fast_compile', 'fast_run')
+                       'fast_compile', 'fast_run', position=30)
 
 
 class LocalCudaMetaOptimizer(LocalMetaOptimizer):
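Note: what the keyword-style registration buys at these call sites can be shown with a toy sketch (ToyGroupDB and some_new_opt are hypothetical, not Theano classes). An explicit position= pins an optimizer's priority, while omitting it falls back to 'last' and appends after everything already registered.

class ToyGroupDB(object):
    """Toy stand-in that only mimics the position bookkeeping."""
    def __init__(self):
        self.__position__ = {}

    def register(self, name, obj, *tags, **kwargs):
        position = kwargs.pop('position', 'last')
        if position == 'last':
            position = (max(self.__position__.values()) + 1
                        if self.__position__ else 0)
        self.__position__[name] = position

db = ToyGroupDB()
db.register('conv_meta', object(), position=0)
db.register('conv_fft_valid', object(), 'conv_fft', position=10)
db.register('local_conv_dnn', object(), 'conv_dnn', position=20)
db.register('some_new_opt', object())  # no position -> 'last' -> 21
print(sorted(db.__position__.items(), key=lambda kv: kv[1]))
# [('conv_meta', 0), ('conv_fft_valid', 10), ('local_conv_dnn', 20),
#  ('some_new_opt', 21)]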
@@ -1733,7 +1733,7 @@ conv_metaopt = ConvMetaOptimizer(
 conv_metaopt.register(dnn.local_conv_dnn_alternative)
 # Finally, we register the metaoptimizer as the first optimizer in
 # conv_groupopt
-conv_groupopt.register('conv_meta', conv_metaopt, 0)
+conv_groupopt.register('conv_meta', conv_metaopt, position=0)
 
 
 @local_optimizer([Conv3D])