提交 fa8bd4ff 作者: affanv14

add conv3d gemm forward alternative

上级 feba3ecb
......@@ -1739,6 +1739,45 @@ def local_abstractconv3d_gemm(node):
return [rval]
@local_optimizer([AbstractConv3d])
def local_abstractconv3d_alt(node):
    """Rewrite a GPU ``AbstractConv3d`` node using the GpuCorr3dMM gradient ops.

    Two configurations are handled:

    * ``border_mode == 'full'`` with ``subsample == (1, 1, 1)``: the kernel
      is reversed along its last three axes when ``filter_flip`` is False,
      its first two axes are swapped, and the result is computed with
      ``GpuCorr3dMM_gradInputs('valid', subsample, filter_dilation)``.
    * ``border_mode == 'valid'`` with ``subsample == (1, 1, 1)`` and
      ``filter_dilation == (1, 1, 1)``: the kernel is reversed along its
      last three axes when ``filter_flip`` is True, both operands get their
      first two axes swapped, and the result is computed with
      ``GpuCorr3dMM_gradWeights``; the first two axes of the output are
      swapped back before it is wrapped with ``as_gpuarray_variable``.

    Parameters
    ----------
    node
        The apply node under consideration.

    Returns
    -------
    list or None
        A one-element list holding the replacement variable, or ``None``
        when the node is not an ``AbstractConv3d``, when either input is
        not a ``GpuArrayType``, or when neither configuration matches.
    """
    conv_op = node.op
    if not isinstance(conv_op, AbstractConv3d):
        return None

    inputs, filters = node.inputs
    both_on_gpu = (isinstance(inputs.type, GpuArrayType) and
                   isinstance(filters.type, GpuArrayType))
    if not both_on_gpu:
        return None

    context = infer_context_name(inputs, filters)
    mode = conv_op.border_mode
    stride = conv_op.subsample
    dilation = conv_op.filter_dilation
    # Axis order that exchanges the first two axes and keeps the other three.
    swap_leading = (1, 0, 2, 3, 4)

    if mode == 'full' and stride == (1, 1, 1):
        if not conv_op.filter_flip:
            filters = filters[:, :, ::-1, ::-1, ::-1]
        filters = filters.dimshuffle(*swap_leading)
        grad_inputs = GpuCorr3dMM_gradInputs('valid', stride, dilation)
        return [grad_inputs(gpu_contiguous(filters),
                            gpu_contiguous(inputs))]

    if (stride == (1, 1, 1) and dilation == (1, 1, 1) and
            mode == 'valid'):
        if conv_op.filter_flip:
            filters = filters[:, :, ::-1, ::-1, ::-1]
        grad_weights = GpuCorr3dMM_gradWeights(mode, stride, dilation)
        swapped_inputs = gpu_contiguous(inputs.dimshuffle(*swap_leading))
        swapped_filters = gpu_contiguous(filters.dimshuffle(*swap_leading))
        out = grad_weights(swapped_inputs, swapped_filters)
        return [as_gpuarray_variable(out.dimshuffle(*swap_leading),
                                     context_name=context)]

    # Any other border mode / subsample / dilation combination is left alone.
    return None
@local_optimizer([AbstractConv3d])
def local_abstractconv3d2d(node):
if not isinstance(node.op, AbstractConv3d):
......@@ -2608,6 +2647,7 @@ conv_metaopt.register([local_abstractconv_gemm_gradweights_alt])
# Register each alternative convolution rewrite with conv_metaopt, one
# single-element list per call.
# NOTE(review): conv_metaopt is defined outside this excerpt; presumably it
# picks among the registered candidates — confirm against its definition.
conv_metaopt.register([local_abstractconv_gradinputs_gemm_alt])
conv_metaopt.register([local_abstractconv_cudnn_alternative])
conv_metaopt.register([local_abstractconv3d2d])
conv_metaopt.register([local_abstractconv3d_alt])
# Register conv_metaopt itself in abstractconv_groupopt under the name
# 'conv_metaopt' with the tag 'conv_meta'.
# NOTE(review): position=0 presumably places it ahead of the other entries
# of the group — confirm against abstractconv_groupopt's register semantics.
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论