提交 04750bbf authored 作者: affanv14's avatar affanv14

revamp selection of optimisers to be more flexible

上级 09a5b97e
......@@ -1455,6 +1455,18 @@ AddConfigVar(
theano.configparser.IntParam(0),
in_c_key=False)
AddConfigVar('metaopt.optimizer_excluding',
("exclude optimizers with these tags. "
"Separate tags with ':'."),
StrParam(""),
in_c_key=False)
AddConfigVar('metaopt.optimizer_including',
("include optimizers with these tags. "
"Separate tags with ':'."),
StrParam(""),
in_c_key=False)
AddConfigVar('profile',
"If VM should collect profile information",
BoolParam(False),
......
......@@ -40,6 +40,7 @@ from theano.tensor.nnet.ctc import ConnectionistTemporalClassification
import theano.tensor.nlinalg as nlinalg
import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg
from collections import Counter
from theano.tests.breakpoint import PdbBreakpoint
......@@ -2068,15 +2069,20 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
return result
def get_opts(self, node):
opts = [opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict['default']]
opt_include = config.optimizer_including.split(':')
opt_exclude = config.optimizer_excluding.split(':')
for in_opt in opt_include:
opts = opts + [opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict[in_opt]]
for ex_opt in opt_exclude:
opts = [opt for opt in opts if opt not in self.tag_dict[ex_opt]]
opts = Counter([opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict['default']])
include_tags = config.metaopt.optimizer_including.split(':')
exclude_tags = config.metaopt.optimizer_excluding.split(':')
for in_opt in include_tags:
opts.update([opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict[in_opt]])
for ex_opt in exclude_tags:
opts.subtract([opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict[ex_opt]])
opts = list(opts + Counter())
return opts
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论