提交 48a87d25 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use the fft version of convolution for the valid mode.

上级 9aa65181
...@@ -32,6 +32,7 @@ from theano.sandbox.cuda.blas import gpu_ger_inplace ...@@ -32,6 +32,7 @@ from theano.sandbox.cuda.blas import gpu_ger_inplace
from theano.sandbox.cuda.blas import gpu_ger_no_inplace from theano.sandbox.cuda.blas import gpu_ger_no_inplace
from theano.sandbox.cuda.blas import (GpuDownsampleFactorMax, from theano.sandbox.cuda.blas import (GpuDownsampleFactorMax,
GpuDownsampleFactorMaxGrad) GpuDownsampleFactorMaxGrad)
from theano.sandbox.cuda.fftconv import conv2d_fft
from theano.sandbox.cuda.nnet import ( from theano.sandbox.cuda.nnet import (
GpuCrossentropySoftmaxArgmax1HotWithBias, GpuCrossentropySoftmaxArgmax1HotWithBias,
GpuCrossentropySoftmax1HotWithBiasDx, GpuCrossentropySoftmax1HotWithBiasDx,
...@@ -1118,6 +1119,17 @@ def local_gpu_conv(node): ...@@ -1118,6 +1119,17 @@ def local_gpu_conv(node):
# differently than the gpu ConvOp # differently than the gpu ConvOp
return [out] return [out]
@register_opt()
@local_optimizer([GpuConv])
def local_conv_fft(node):
    """Replace a 'valid'-mode ``GpuConv`` node with the FFT-based convolution.

    :param node: the graph node being considered by the local optimizer.
    :return: a one-element list holding the ``conv2d_fft`` output when the
        node is a ``GpuConv`` in 'valid' border mode with unit subsampling;
        ``None`` otherwise, which leaves the node untouched.
    """
    op = node.op
    if not isinstance(op, GpuConv) or op.border_mode != 'valid':
        return None

    # Only the two inputs and the shape hints are forwarded to conv2d_fft
    # below; no stride information is passed.  Rewriting a strided GpuConv
    # would therefore silently change the result, so decline in that case.
    subsample = getattr(op, 'subsample', None)
    if subsample is not None and tuple(subsample) != (1, 1):
        return None

    # NOTE(review): assumes op.imgshp / op.kshp are in the layout that
    # conv2d_fft expects for image_shape / filter_shape -- TODO confirm.
    img, kern = node.inputs[0], node.inputs[1]
    return [conv2d_fft(img, kern,
                       image_shape=op.imgshp,
                       filter_shape=op.kshp)]
import theano.tensor.signal.downsample as downsample import theano.tensor.signal.downsample as downsample
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论