提交 b540b158 authored 作者: affanv14's avatar affanv14

add cudnn conv forward alternative

上级 cd505659
...@@ -2993,6 +2993,50 @@ def local_abstractconv_cudnn_alternative(node): ...@@ -2993,6 +2993,50 @@ def local_abstractconv_cudnn_alternative(node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d, AbstractConv3d_gradWeights, AbstractConv3d_gradInputs])
def local_abstractconv3d_cudnn_alternative(node):
    """Replace an ``AbstractConv3d`` by a cuDNN 3d convolution driven by an
    alternative ``direction_hint``.

    For ``border_mode='full'`` the forward conv is computed as the gradient
    w.r.t. inputs ('bprop inputs'); for ``border_mode='valid'`` without
    dilation, as the gradient w.r.t. weights ('bprop weights').  This gives
    the conv meta-optimizer additional cuDNN algorithms to time.

    Returns a one-element list containing the replacement variable, or
    ``None`` when the rewrite does not apply (including for the gradient
    ops, which this optimizer does not yet handle).
    """
    if not isinstance(node.op, (AbstractConv3d,
                                AbstractConv3d_gradWeights,
                                AbstractConv3d_gradInputs)):
        return
    # Dilated convolutions require cuDNN v6 (6000) or later.
    if version(raises=False) < 6000 and node.op.filter_dilation != (1, 1, 1):
        return None
    inp1 = node.inputs[0]
    inp2 = node.inputs[1]
    if not dnn_available(inp1.type.context_name):
        return
    op = node.op
    border_mode = op.border_mode
    subsample = op.subsample
    filter_dilation = op.filter_dilation
    conv_mode = 'conv' if op.filter_flip else 'cross'
    if isinstance(op, AbstractConv3d):
        # The alternative directions only apply to unit subsample and to
        # explicit 'full'/'valid' border modes.
        if border_mode == 'half' or subsample != (1, 1, 1):
            return None
        if border_mode == 'full':
            direction_hint = 'bprop inputs'
        elif border_mode == 'valid' and filter_dilation == (1, 1, 1):
            direction_hint = 'bprop weights'
        else:
            return None
        rval = dnn_conv3d(inp1, inp2,
                          border_mode=border_mode,
                          subsample=subsample,
                          dilation=filter_dilation,
                          direction_hint=direction_hint,
                          conv_mode=conv_mode)
        # Local optimizers must return a *list* of replacement variables
        # (see the 2d variant, which does `return [rval]`); returning the
        # bare variable breaks the optimization framework's contract.
        return [rval]
@local_optimizer([AbstractConv2d_gradWeights, AbstractConv3d_gradWeights]) @local_optimizer([AbstractConv2d_gradWeights, AbstractConv3d_gradWeights])
def local_abstractconv_gw_cudnn(node): def local_abstractconv_gw_cudnn(node):
ctx = infer_context_name(*node.inputs) ctx = infer_context_name(*node.inputs)
......
...@@ -2649,7 +2649,8 @@ register_opt('fast_compile')(abstractconv_groupopt) ...@@ -2649,7 +2649,8 @@ register_opt('fast_compile')(abstractconv_groupopt)
from .dnn import (local_abstractconv_cudnn, from .dnn import (local_abstractconv_cudnn,
local_abstractconv_gw_cudnn, local_abstractconv_gw_cudnn,
local_abstractconv_gi_cudnn, # noqa: 402 local_abstractconv_gi_cudnn, # noqa: 402
local_abstractconv_cudnn_alternative) local_abstractconv_cudnn_alternative,
local_abstractconv3d_cudnn_alternative)
abstractconv_groupopt.register('local_abstractconv_dnn', abstractconv_groupopt.register('local_abstractconv_dnn',
local_abstractconv_cudnn, 20, local_abstractconv_cudnn, 20,
...@@ -2706,6 +2707,7 @@ conv_metaopt.register([local_abstractconv3d2d]) ...@@ -2706,6 +2707,7 @@ conv_metaopt.register([local_abstractconv3d2d])
conv_metaopt.register([local_abstractconv3d_alt]) conv_metaopt.register([local_abstractconv3d_alt])
conv_metaopt.register([local_abstractconv3d_gemm_gradweights_alt]) conv_metaopt.register([local_abstractconv3d_gemm_gradweights_alt])
conv_metaopt.register([local_abstractconv3d_gradinputs_gemm_alt]) conv_metaopt.register([local_abstractconv3d_gradinputs_gemm_alt])
conv_metaopt.register([local_abstractconv3d_cudnn_alternative])
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0) abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation # Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论