Commit 44d464c0 authored by affanv14

add support for excluding and including opts

Parent 647903ff
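The mechanism this commit adds is small enough to distill before reading the diff. The sketch below is not the Theano classes themselves but a minimal stand-alone model of them: register() files an optimizer under a list of tags, the base candidate set is whatever carries the 'default' tag, and colon-separated include/exclude lists (mirroring config.optimizer_including and config.optimizer_excluding) widen or narrow that set. The TaggedRegistry name and the string-valued "optimizers" are illustrative only.

from collections import defaultdict

class TaggedRegistry(object):
    """Stand-alone model of the tag bookkeeping added in this commit."""
    def __init__(self):
        self.optimizers = []
        self.tag_dict = defaultdict(list)

    def register(self, optimizer, tag_list):
        # mirrors LocalMetaOptimizer.register(): one optimizer, many tags
        self.optimizers.append(optimizer)
        for tag in tag_list:
            self.tag_dict[tag].append(optimizer)

    def get_opts(self, including='', excluding=''):
        # mirrors ConvMetaOptimizer.get_opts(): start from the 'default'
        # set, then apply the colon-separated include/exclude lists
        opts = [o for o in self.optimizers if o in self.tag_dict['default']]
        for tag in including.split(':'):
            # an empty or unknown tag maps to an empty list, so it is a no-op
            opts = opts + [o for o in self.optimizers if o in self.tag_dict[tag]]
        for tag in excluding.split(':'):
            opts = [o for o in opts if o not in self.tag_dict[tag]]
        return opts

reg = TaggedRegistry()
reg.register('conv_gemm_opt', ['default', 'conv_gemm'])
reg.register('conv_cudnn_opt', ['default', 'cudnn', 'conv_dnn'])
reg.register('conv3d2d_opt', ['alternative', 'conv3d2d'])

print(reg.get_opts())                      # ['conv_gemm_opt', 'conv_cudnn_opt']
print(reg.get_opts(including='conv3d2d'))  # ['conv_gemm_opt', 'conv_cudnn_opt', 'conv3d2d_opt']
print(reg.get_opts(excluding='conv_dnn'))  # ['conv_gemm_opt']

One property worth noticing, present in this sketch and in the diff alike: an optimizer tagged both 'default' and an explicitly included tag ends up in the candidate list twice. Since candidates are only benchmarked, a duplicate just costs an extra timing run.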
@@ -1136,13 +1136,15 @@ class LocalMetaOptimizer(LocalOptimizer):
         self.optimizers = list(optimizers)
         self.verbose = config.metaopt.verbose
         self.track_dict = defaultdict(lambda: [])
+        self.tag_dict = defaultdict(lambda: [])
 
-    def register(self, optimizers):
-        self.optimizers.extend(optimizers)
-        for o in optimizers:
-            for c in o.tracks():
-                self.track_dict[c].append(o)
-                self._tracks.append(c)
+    def register(self, optimizer, tag_list):
+        self.optimizers.append(optimizer)
+        for c in optimizer.tracks():
+            self.track_dict[c].append(optimizer)
+            self._tracks.append(c)
+        for tag in tag_list:
+            self.tag_dict[tag].append(optimizer)
 
     def tracks(self):
         return self._tracks
@@ -1181,9 +1183,9 @@ class LocalMetaOptimizer(LocalOptimizer):
         # compile the resulting subgraphs and time their execution
         if self.verbose > 1:
             print(("%s meta-optimizing %s (%d choices):" %
-                   (self.__class__.__name__, node, len(self.track_dict[type(node.op)]))))
+                   (self.__class__.__name__, node, len(self.get_opts(node)))))
         timings = []
-        for opt in self.track_dict[type(node.op)]:
+        for opt in self.get_opts(node):
             outputs = opt.transform(node)
             if outputs:
                 try:
@@ -1218,6 +1220,12 @@ class LocalMetaOptimizer(LocalOptimizer):
         """
         raise NotImplementedError()
 
+    def get_opts(self, node):
+        """
+        Can be overridden to change the way opts are selected.
+        """
+        return self.track_dict[type(node.op)]
+
     def time_call(self, fn):
         start = time.time()
         fn()
@@ -2067,6 +2067,18 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
         return result
 
+    def get_opts(self, node):
+        opts = [opt for opt in self.track_dict[type(node.op)]
+                if opt in self.tag_dict['default']]
+        opt_include = config.optimizer_including.split(':')
+        opt_exclude = config.optimizer_excluding.split(':')
+        for in_opt in opt_include:
+            opts = opts + [opt for opt in self.track_dict[type(node.op)]
+                           if opt in self.tag_dict[in_opt]]
+        for ex_opt in opt_exclude:
+            opts = [opt for opt in opts if opt not in self.tag_dict[ex_opt]]
+        return opts
+
 
 # This deals with any abstract convs that have a transfer somewhere
 @register_opt('fast_compile', 'conv_dnn', 'cudnn')
@@ -2690,24 +2702,44 @@ abstractconv_groupopt.register('local_abstractconv3d_gradinputs',
                                'gpuarray', 'fast_compile', 'fast_run')
 
 conv_metaopt = ConvMetaOptimizer()
-running_list = ['+fast_run' if config.mode == 'Mode' else '+' + config.mode]
-if config.optimizer_including:
-    running_list += ['+' + name for name in config.optimizer_including.split(':')]
-if config.optimizer_excluding:
-    running_list += ['-' + name for name in config.optimizer_excluding.split(':')]
-conv_metaopt.register(abstractconv_groupopt.query(*running_list).opts)
-conv_metaopt.register([local_abstractconv_gemm_alternative])
-conv_metaopt.register([local_abstractconv_gemm_gradweights_alt])
-conv_metaopt.register([local_abstractconv_gradinputs_gemm_alt])
-conv_metaopt.register([local_abstractconv_cudnn_alternative])
-conv_metaopt.register([local_abstractconv3d2d])
-conv_metaopt.register([local_abstractconv3d_alt])
-conv_metaopt.register([local_abstractconv3d_gemm_gradweights_alt])
-conv_metaopt.register([local_abstractconv3d_gradinputs_gemm_alt])
-conv_metaopt.register([local_abstractconv3d_cudnn_alternative])
+conv_metaopt.register(local_abstractconv_cudnn,
+                      ['default', 'cudnn', 'conv_dnn'])
+conv_metaopt.register(local_abstractconv_gw_cudnn,
+                      ['default', 'cudnn', 'conv_dnn'])
+conv_metaopt.register(local_abstractconv_gi_cudnn,
+                      ['default', 'cudnn', 'conv_dnn'])
+conv_metaopt.register(local_abstractconv_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_gradweights_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d_gradweights_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_gradinputs_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d_gradinputs_gemm,
+                      ['default', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_gemm_alternative,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_gemm_gradweights_alt,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_gradinputs_gemm_alt,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv_cudnn_alternative,
+                      ['default', 'alternative', 'cudnn', 'conv_dnn'])
+conv_metaopt.register(local_abstractconv3d_cudnn_alternative,
+                      ['default', 'alternative', 'cudnn', 'conv_dnn'])
+conv_metaopt.register(local_abstractconv3d_alt,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d_gemm_gradweights_alt,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d_gradinputs_gemm_alt,
+                      ['default', 'alternative', 'conv_gemm'])
+conv_metaopt.register(local_abstractconv3d2d,
+                      ['alternative', 'conv3d2d'])
 abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
 
 # Register cuDNN batch normalization implementation
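In practice this wires the meta-optimizer into the ordinary Theano flags. Assuming a configuration in which conv_metaopt itself runs (it is registered under the 'conv_meta' tag above), something like THEANO_FLAGS='optimizer_including=conv3d2d' makes the 'conv3d2d'-tagged local_abstractconv3d2d a timed candidate even though it is not tagged 'default', while THEANO_FLAGS='optimizer_excluding=conv_dnn' drops every cuDNN-tagged candidate before timing, since ConvMetaOptimizer.get_opts() starts from the 'default' tag and then applies the two colon-separated lists.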