提交 7294d64c authored 作者: f0k's avatar f0k

Make conv_gemm optimizer support strided convolution

上级 837b1ca5
...@@ -1286,13 +1286,12 @@ def local_gpu_downsample_factor_max_grad(node): ...@@ -1286,13 +1286,12 @@ def local_gpu_downsample_factor_max_grad(node):
@local_optimizer([GpuConv]) @local_optimizer([GpuConv])
def local_conv_gemm(node): def local_conv_gemm(node):
if (isinstance(node.op, GpuConv) and if (isinstance(node.op, GpuConv) and
node.op.border_mode in ['full', 'valid'] and node.op.border_mode in ['full', 'valid']):
node.op.subsample == (1, 1)):
img, kern = node.inputs img, kern = node.inputs
img = gpu_contiguous(img) img = gpu_contiguous(img)
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
kern = gpu_contiguous(kern) kern = gpu_contiguous(kern)
return [GpuCorrMM(node.op.border_mode)(img, kern)] return [GpuCorrMM(node.op.border_mode, node.op.subsample)(img, kern)]
gpu_optimizer.register("conv_gemm", local_conv_gemm) gpu_optimizer.register("conv_gemm", local_conv_gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论