提交 04750bbf authored 作者: affanv14's avatar affanv14

revamp selection of optimisers to be more flexible

上级 09a5b97e
...@@ -1455,6 +1455,18 @@ AddConfigVar( ...@@ -1455,6 +1455,18 @@ AddConfigVar(
theano.configparser.IntParam(0), theano.configparser.IntParam(0),
in_c_key=False) in_c_key=False)
AddConfigVar('metaopt.optimizer_excluding',
("exclude optimizers with these tags. "
"Separate tags with ':'."),
StrParam(""),
in_c_key=False)
AddConfigVar('metaopt.optimizer_including',
("include optimizers with these tags. "
"Separate tags with ':'."),
StrParam(""),
in_c_key=False)
AddConfigVar('profile', AddConfigVar('profile',
"If VM should collect profile information", "If VM should collect profile information",
BoolParam(False), BoolParam(False),
......
...@@ -40,6 +40,7 @@ from theano.tensor.nnet.ctc import ConnectionistTemporalClassification ...@@ -40,6 +40,7 @@ from theano.tensor.nnet.ctc import ConnectionistTemporalClassification
import theano.tensor.nlinalg as nlinalg import theano.tensor.nlinalg as nlinalg
import theano.tensor.signal.pool as pool import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg import theano.tensor.slinalg as slinalg
from collections import Counter
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
...@@ -2068,15 +2069,20 @@ class ConvMetaOptimizer(LocalMetaOptimizer): ...@@ -2068,15 +2069,20 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
return result return result
def get_opts(self, node): def get_opts(self, node):
opts = [opt for opt in self.track_dict[type(node.op)] opts = Counter([opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict['default']] if opt in self.tag_dict['default']])
opt_include = config.optimizer_including.split(':') include_tags = config.metaopt.optimizer_including.split(':')
opt_exclude = config.optimizer_excluding.split(':') exclude_tags = config.metaopt.optimizer_excluding.split(':')
for in_opt in opt_include:
opts = opts + [opt for opt in self.track_dict[type(node.op)] for in_opt in include_tags:
if opt in self.tag_dict[in_opt]] opts.update([opt for opt in self.track_dict[type(node.op)]
for ex_opt in opt_exclude: if opt in self.tag_dict[in_opt]])
opts = [opt for opt in opts if opt not in self.tag_dict[ex_opt]]
for ex_opt in exclude_tags:
opts.subtract([opt for opt in self.track_dict[type(node.op)]
if opt in self.tag_dict[ex_opt]])
opts = list(opts + Counter())
return opts return opts
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论