提交 17f7b9ec authored 作者: f0k's avatar f0k

Add an additional meta-optimizer entry that tries a second code path for conv_dnn

上级 ab6c786e
......@@ -1207,7 +1207,37 @@ if True:
border_mode=border_mode, subsample=subsample,
direction_hint=direction_hint)]
@register_opt('cudnn')
# This optimizer is registered in opt.py as part of the meta-optimizer.
# It tries exactly the opposite code path of what local_conv_dnn() uses,
# because for some input/kernel shape configurations, this is faster.
@local_optimizer([GpuConv])
def local_conv_dnn_alternative(node):
    """Replace a GpuConv node by dnn_conv with a flipped direction hint.

    Only 'full' and 'valid' border modes with a subsample of (1, 1) are
    handled.  For 'full', the hint is forced to 'forward!'; for 'valid',
    'bprop weights' becomes 'forward' and any other hint becomes
    'bprop weights'.

    Returns a one-element list holding the dnn_conv result, or None (no
    replacement) when dnn_available() is false, when the node's op is not
    a GpuConv, or when the border mode / subsample is not handled.
    """
    if not dnn_available():
        return
    conv_op = node.op
    if not isinstance(conv_op, GpuConv):
        return

    mode = conv_op.border_mode
    stride = conv_op.subsample
    if mode not in ['full', 'valid'] or stride != (1, 1):
        return

    img, kern = node.inputs
    hint = conv_op.direction_hint
    if mode == 'full':
        # A full convolution: ask for the forward pass rather than the
        # backward pass with respect to the inputs.
        hint = 'forward!'
    elif mode == 'valid':
        # A valid convolution: exchange the forward pass and the backward
        # pass with respect to the weights.
        hint = 'forward' if hint == 'bprop weights' else 'bprop weights'

    replacement = dnn_conv(img, kern,
                           border_mode=mode, subsample=stride,
                           direction_hint=hint)
    return [replacement]
# DISABLED as there are problems in the handling of borders
# @register_opt('cudnn')
@local_optimizer([GpuDownsampleFactorMax])
def local_pool_dnn(node):
if not dnn_available():
......
......@@ -1395,7 +1395,9 @@ class ConvMetaOptimizer(LocalCudaMetaOptimizer):
# We just register all optimizers from conv_groupopt with the metaoptimizer:
# conv_groupopt is queried with every name in conv_groupopt._names (each
# prefixed by '+') and the resulting .opts are handed to ConvMetaOptimizer.
conv_metaopt = ConvMetaOptimizer(
conv_groupopt.query(*['+' + name for name in conv_groupopt._names]).opts)
# Then we add some optimizers that try less obvious options
# (here: the alternative code path for the cuDNN convolution).
conv_metaopt.register(dnn.local_conv_dnn_alternative)
# Finally, we register the metaoptimizer as the first optimizer in conv_groupopt
# (under the name 'conv_meta', at position 0).
conv_groupopt.register('conv_meta', conv_metaopt, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论