提交 cd505659 authored 作者: affanv14's avatar affanv14

add GPUCorrMM conv3d gradinputs alternative

上级 0a7fa210
...@@ -1980,6 +1980,31 @@ def local_abstractconv3d_gradinputs_gemm(node): ...@@ -1980,6 +1980,31 @@ def local_abstractconv3d_gradinputs_gemm(node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d_gradInputs])
def local_abstractconv3d_gradinputs_gemm_alt(node):
if not isinstance(node.op, AbstractConv3d_gradInputs):
return None
kern, topgrad, shape = node.inputs
if not isinstance(kern.type, GpuArrayType) or \
not isinstance(topgrad.type, GpuArrayType):
return None
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
if border_mode == 'valid' and subsample == (1, 1, 1):
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM(border_mode='full',
subsample=subsample,
filter_dilation=filter_dilation)(
gpu_contiguous(topgrad),
gpu_contiguous(kern.dimshuffle(1, 0, 2, 3, 4)))
return [rval]
else:
return None
class ConvMetaOptimizer(LocalMetaOptimizer): class ConvMetaOptimizer(LocalMetaOptimizer):
def __init__(self, optimizers=()): def __init__(self, optimizers=()):
...@@ -2680,6 +2705,7 @@ conv_metaopt.register([local_abstractconv_cudnn_alternative]) ...@@ -2680,6 +2705,7 @@ conv_metaopt.register([local_abstractconv_cudnn_alternative])
conv_metaopt.register([local_abstractconv3d2d]) conv_metaopt.register([local_abstractconv3d2d])
conv_metaopt.register([local_abstractconv3d_alt]) conv_metaopt.register([local_abstractconv3d_alt])
conv_metaopt.register([local_abstractconv3d_gemm_gradweights_alt]) conv_metaopt.register([local_abstractconv3d_gemm_gradweights_alt])
conv_metaopt.register([local_abstractconv3d_gradinputs_gemm_alt])
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0) abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation # Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论