提交 bab9eacb · 作者: Nicolas Ballas

Update opt to replace Conv3D, ConvGrad3D and ConvTransp3D instead of GpuConv3D, GpuConvGrad3D and GpuConvTransp3D

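Update the FFT optimizations (local_conv3d_fft, local_convgrad3d_fft and local_convtransp3d_fft) so that they match and replace Conv3D, ConvGrad3D and ConvTransp3D instead of GpuConv3D, GpuConvGrad3D and GpuConvTransp3D.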
上级 1db4ac0f
......@@ -1256,8 +1256,8 @@ def local_conv_fft_full(node):
gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
from theano.sandbox.cuda.GpuConv3D import GpuConv3D
@local_optimizer([GpuConv3D])
from theano.tensor.nnet.Conv3D import Conv3D
@local_optimizer([Conv3D])
def local_conv3d_fft(node):
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0])
......@@ -1265,7 +1265,7 @@ def local_conv3d_fft(node):
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError:
return False
if (isinstance(node.op, GpuConv3D) and
if (isinstance(node.op, Conv3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
......@@ -1277,13 +1277,13 @@ def local_conv3d_fft(node):
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2])]
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
from theano.sandbox.cuda.GpuConvGrad3D import GpuConvGrad3D
@local_optimizer([GpuConvGrad3D])
from theano.tensor.nnet.ConvGrad3D import ConvGrad3D
@local_optimizer([ConvGrad3D])
def local_convgrad3d_fft(node):
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[1][0])
......@@ -1291,7 +1291,7 @@ def local_convgrad3d_fft(node):
stride_z = tensor.get_scalar_constant_value(node.inputs[1][2])
except tensor.NotScalarConstantError:
return False
if (isinstance(node.op, GpuConvGrad3D) and
if (isinstance(node.op, ConvGrad3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
......@@ -1303,13 +1303,13 @@ def local_convgrad3d_fft(node):
f = f.dimshuffle(4, 0, 1, 2, 3)
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
return [gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0))]
return [rval.dimshuffle(1, 2, 3, 4, 0)]
gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft)
from theano.sandbox.cuda.GpuConvTransp3D import GpuConvTransp3D
@local_optimizer([GpuConvTransp3D])
from theano.tensor.nnet.ConvTransp3D import ConvTransp3D
@local_optimizer([ConvTransp3D])
def local_convtransp3d_fft(node):
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[2][0])
......@@ -1317,7 +1317,7 @@ def local_convtransp3d_fft(node):
stride_z = tensor.get_scalar_constant_value(node.inputs[2][2])
except tensor.NotScalarConstantError:
return False
if (isinstance(node.op, GpuConvTransp3D) and
if (isinstance(node.op, ConvTransp3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
......@@ -1329,12 +1329,11 @@ def local_convtransp3d_fft(node):
f = f.dimshuffle(0, 4, 1, 2, 3)
rval = conv3d_fft(f, x, border_mode='full', pad_last_dim=True)
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))]
return [rval.dimshuffle(0, 2, 3, 4, 1)]
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
import theano.tensor.signal.downsample as downsample
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论