提交 b14e9943,作者:Nicolas Ballas

Move the intermediate 3D FFT conversion helper functions inside the optimizers

上级 e67f96cd
...@@ -1256,12 +1256,19 @@ def local_conv_fft_full(node): ...@@ -1256,12 +1256,19 @@ def local_conv_fft_full(node):
gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid) gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
gpu_optimizer.register("conv_fft_full", local_conv_fft_full) gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
from theano.sandbox.cuda.GpuConv3D import GpuConv3D from theano.sandbox.cuda.GpuConv3D import GpuConv3D
def _gpu_conv3d_to_fftconv(node): @local_optimizer([GpuConv3D])
def local_conv3d_fft(node):
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError:
return False
if (isinstance(node.op, GpuConv3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings # we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t) # Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
x = node.inputs[0] x = node.inputs[0]
x = gpu_from_host(x.dimshuffle(0, 4, 1, 2, 3)) x = gpu_from_host(x.dimshuffle(0, 4, 1, 2, 3))
...@@ -1271,30 +1278,21 @@ def _gpu_conv3d_to_fftconv(node): ...@@ -1271,30 +1278,21 @@ def _gpu_conv3d_to_fftconv(node):
rval = conv3d_fft(x, f) rval = conv3d_fft(x, f)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c) # Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]) rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2])
return rval return rval
@local_optimizer([GpuConv3D]) gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
def local_conv3d_fft(node):
from theano.sandbox.cuda.GpuConvGrad3D import GpuConvGrad3D
@local_optimizer([GpuConvGrad3D])
def local_convgrad3d_fft(node):
try: try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0]) stride_x = tensor.get_scalar_constant_value(node.inputs[1][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1]) stride_y = tensor.get_scalar_constant_value(node.inputs[1][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2]) stride_z = tensor.get_scalar_constant_value(node.inputs[1][2])
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
pass return False
if (isinstance(node.op, GpuConv3D) and if (isinstance(node.op, GpuConvGrad3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)): (stride_x, stride_y, stride_z) == (1, 1, 1)):
return [_gpu_conv3d_to_fftconv(node)]
gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
from theano.sandbox.cuda.GpuConvGrad3D import GpuConvGrad3D
def _gpu_convgrad3d_to_fftconv(node):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, ic) to (ic, b, 0, 1, t) # Shuffle inputs signal from (b, 0, 1, t, ic) to (ic, b, 0, 1, t)
x = node.inputs[0] x = node.inputs[0]
x = x.dimshuffle(4, 0, 1, 2, 3) x = x.dimshuffle(4, 0, 1, 2, 3)
...@@ -1304,30 +1302,23 @@ def _gpu_convgrad3d_to_fftconv(node): ...@@ -1304,30 +1302,23 @@ def _gpu_convgrad3d_to_fftconv(node):
rval = conv3d_fft(x, f) rval = conv3d_fft(x, f)
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic) # Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0)) rval = gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0))
return rval return rval
@local_optimizer([GpuConvGrad3D]) gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft)
def local_convgrad3d_fft(node):
from theano.sandbox.cuda.GpuConvTransp3D import GpuConvTransp3D
@local_optimizer([GpuConvTransp3D])
def local_convtransp3d_fft(node):
try: try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0]) stride_x = tensor.get_scalar_constant_value(node.inputs[2][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1]) stride_y = tensor.get_scalar_constant_value(node.inputs[2][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2]) stride_z = tensor.get_scalar_constant_value(node.inputs[2][2])
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
pass return False
if (isinstance(node.op, GpuConvGrad3D) and if (isinstance(node.op, GpuConvTransp3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)): (stride_x, stride_y, stride_z) == (1, 1, 1)):
return [_gpu_convgrad3d_to_fftconv(node)]
gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft)
from theano.sandbox.cuda.GpuConvTransp3D import GpuConvTransp3D
def _gpu_convtransp3d_to_fftconv(node):
# we import conv3d_fft locally to avoid pycuda warnings # we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle filters from (oc, 0, 1, t, ic) to (ic, oc, 0, 1, t) # Shuffle filters from (oc, 0, 1, t, ic) to (ic, oc, 0, 1, t)
x = node.inputs[0] x = node.inputs[0]
x = x.dimshuffle(4, 0, 1, 2, 3) x = x.dimshuffle(4, 0, 1, 2, 3)
...@@ -1337,22 +1328,8 @@ def _gpu_convtransp3d_to_fftconv(node): ...@@ -1337,22 +1328,8 @@ def _gpu_convtransp3d_to_fftconv(node):
rval = conv3d_fft(f, x, border_mode='full') rval = conv3d_fft(f, x, border_mode='full')
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic) # Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1)) rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))
return rval return rval
@local_optimizer([GpuConvTransp3D])
def local_convtransp3d_fft(node):
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError:
pass
if (isinstance(node.op, GpuConvTransp3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
return [_gpu_convtransp3d_to_fftconv(node)]
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft) gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论