提交 76714ad0 authored 作者: affanv14's avatar affanv14

add alternative optimizer for cudnn forward pass

上级 9041a214
......@@ -2888,6 +2888,43 @@ def local_abstractconv_cudnn(node):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
@local_optimizer([AbstractConv2d])
def local_abstractconv_cudnn_alternative(node):
if not isinstance(node.op, AbstractConv2d):
return
if version(raises=False) < 6000 and node.op.filter_dilation != (1, 1):
return None
inp1 = node.inputs[0]
inp2 = node.inputs[1]
if not dnn_available(inp1.type.context_name):
return
if node.op.filter_flip:
conv_mode = 'conv'
else:
conv_mode = 'cross'
if node.op.border_mode == 'full':
direction_hint = 'bprop inputs'
elif node.op.border_mode == 'valid':
direction_hint = 'bprop weights'
else:
return None
rval = dnn_conv(inp1, inp2,
border_mode=node.op.border_mode,
subsample=node.op.subsample,
dilation=node.op.filter_dilation,
direction_hint=direction_hint,
conv_mode=conv_mode,
num_groups=node.op.num_groups)
return [rval]
@local_optimizer([AbstractConv2d_gradWeights, AbstractConv3d_gradWeights])
def local_abstractconv_gw_cudnn(node):
ctx = infer_context_name(*node.inputs)
......
......@@ -2522,8 +2522,10 @@ register_opt('fast_compile')(abstractconv_groupopt)
# We import these opts here instead of at the top of this file
# to avoid a circular dependency problem with dnn
from .dnn import (local_abstractconv_cudnn, local_abstractconv_gw_cudnn,
local_abstractconv_gi_cudnn) # noqa: 402
from .dnn import (local_abstractconv_cudnn,
local_abstractconv_gw_cudnn,
local_abstractconv_gi_cudnn, # noqa: 402
local_abstractconv_cudnn_alternative)
abstractconv_groupopt.register('local_abstractconv_dnn',
local_abstractconv_cudnn, 20,
......@@ -2575,6 +2577,7 @@ conv_metaopt.register(abstractconv_groupopt.query(*running_list).opts)
conv_metaopt.register([local_abstractconv_gemm_alternative])
conv_metaopt.register([local_abstractconv_gemm_gradweights_alt])
conv_metaopt.register([local_abstractconv_gradinputs_gemm_alt])
conv_metaopt.register([local_abstractconv_cudnn_alternative])
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论