提交 498f9298 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

[3DConvFFT] add fft optimization for GpuConv3D

上级 c40b1bd9
......@@ -80,6 +80,9 @@ def register_opt(*tags, **kwargs):
return local_opt
return f
#register local_track_shape_i at this level too
#to make multi-level lift of shape work.
register_opt()(theano.tensor.opt.local_track_shape_i)
......@@ -1257,6 +1260,51 @@ gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
from theano.sandbox.cuda.GpuConv3D import GpuConv3D
def _gpu_conv3d_to_fftconv(node):
# shared helper function for local_conv_fft_valid and local_conv_fft_full.
# we import conv2d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
x = node.inputs[0]
x = gpu_from_host(x.dimshuffle(0, 4, 1, 2, 3))
# Shuflle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
f = node.inputs[1]
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
# TODO: If the user supplied the full nonsymbolic image_shape and
# filter_shape in conv2d(), we could pass it on to conv2d_fft(). However,
# information on batch size and channel counts is currently discarded
# when a ConvOp is replaced by a GpuConv, so this would need more changes.
#if (node.op.imshp is not None) and (None not in node.op.imshp):
# kwargs['image_shape'] = (bsize, inchannels) + node.op.imshp
#if (node.op.kshp is not None) and (None not in node.op.kshp):
# kwargs['filter_shape'] = (outchannels, inchannels) + node.op.kshp
rval = conv3d_fft(x, f)
# Shuffle back (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))
# Add biais
#rval = rval + node.inputs[2]
return rval
@local_optimizer([GpuConv3D])
def local_conv3d_fft(node):
if (isinstance(node.op, GpuConv3D)# and
# node.inputs[3] == (1, 1, 1)]):
):
return [_gpu_conv3d_to_fftconv(node)]
gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
import theano.tensor.signal.downsample as downsample
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论