提交 09040cd6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

optimize abstract => corrmm with padding

It was working for grad, but not forward
上级 05a16b2a
...@@ -437,12 +437,18 @@ def local_conv2d_corrmm(node): ...@@ -437,12 +437,18 @@ def local_conv2d_corrmm(node):
not isinstance(kern.type, CudaNdarrayType)): not isinstance(kern.type, CudaNdarrayType)):
return None return None
if node.op.border_mode in ['full', 'valid']:
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
if (border_mode == 'valid') or (subsample != (1,1)): if (border_mode == 'full') and (subsample == (1, 1)):
# need to flip the kernel for valid convolution if not node.op.filters_flip:
kern = kern[:, :, ::-1, ::-1]
# need to dimshuffle the kernel for full convolution
kern = kern.dimshuffle(1, 0, 2, 3)
# call GpuCorrMM_gradInputs
rval = GpuCorrMM_gradInputs('valid', subsample)(
gpu_contiguous(kern), gpu_contiguous(img))
else:
# need to flip the kernel if necessary
if node.op.filters_flip: if node.op.filters_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
# By default use GpuCorrMM # By default use GpuCorrMM
...@@ -479,13 +485,8 @@ def local_conv2d_corrmm(node): ...@@ -479,13 +485,8 @@ def local_conv2d_corrmm(node):
gpu_contiguous(img.dimshuffle(1, 0, 2, 3)), gpu_contiguous(img.dimshuffle(1, 0, 2, 3)),
gpu_contiguous(kern.dimshuffle(1, 0, 2, 3)) gpu_contiguous(kern.dimshuffle(1, 0, 2, 3))
).dimshuffle(1, 0, 2, 3)) ).dimshuffle(1, 0, 2, 3))
elif (border_mode == 'full'):
# need to dimshuffle the kernel for full convolution
kern = kern.dimshuffle(1, 0, 2, 3)
# call GpuCorrMM_gradInputs
rval = GpuCorrMM_gradInputs('valid', subsample)(
gpu_contiguous(kern), gpu_contiguous(img))
return [rval] return [rval]
register_specialize_device(local_conv2d_corrmm, 'conv_gemm') register_specialize_device(local_conv2d_corrmm, 'conv_gemm')
@local_optimizer([AbstractConv2d_gradWeights]) @local_optimizer([AbstractConv2d_gradWeights])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论