提交 658ec11f authored 作者: Ziye Fan's avatar Ziye Fan

change final_opt to be a keyword argument

上级 50dc6640
@@ -230,16 +230,18 @@ class EquilibriumDB(DB):
     def register(self, name, obj, *tags, **kwtags):
         # if name == 'cut_gpua_constant_transfers':
         #     import ipdb;ipdb.set_trace()
-        if 'final_opt' in tags:
-            final_opt = True
+        if 'final_opt' in kwtags:
+            final_opt = kwtags['final_opt']
+            kwtags.pop('final_opt', None)
         else:
             final_opt = False
         super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
         self.__final__[name] = final_opt

     def query(self, *tags, **kwtags):
-        opts = super(EquilibriumDB, self).query(*tags, **kwtags)
-        final_opts = [o for o in opts if self.__final__.get(o.name, False)]
+        _opts = super(EquilibriumDB, self).query(*tags, **kwtags)
+        final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
+        opts = [o for o in _opts if o not in final_opts]
         if len(final_opts) == 0:
             final_opts = None
         return opt.EquilibriumOptimizer(
...
@@ -27,10 +27,11 @@ def register_opt(*tags, **kwargs):
     if any([not isinstance(t, str) for t in tags]):
         raise RuntimeError("Bad call to register_opt."
                            " All tags must be strings.", tags)

     def f(local_opt):
         name = (kwargs and kwargs.pop('name')) or local_opt.__name__
         gpu_optimizer.register(name, local_opt, 'fast_run', 'fast_compile',
-                               'gpu', *tags)
+                               'gpu', *tags, **kwargs)
         return local_opt
     return f
...
@@ -94,13 +94,13 @@ optdb.register('gpu_after_fusion',
 # Register merge_optimizer as a global opt
 gpu_optimizer.register('gpu_merge', theano.gof.opt.merge_optimizer,
-                       'fast_run', 'fast_compile', 'final_opt')
+                       'fast_run', 'fast_compile', final_opt=True)

 # register local_track_shape_i at this level too
 # to make multi-level lift of shape work.
 register_opt()(theano.tensor.opt.local_track_shape_i)

-register_opt('final_opt', name='gpu_constant_folding')(
+register_opt(final_opt=True, name='gpu_constant_folding')(
     tensor.opt.constant_folding)

 register_opt()(theano.tensor.opt.local_subtensor_make_vector)
...
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论