提交 21f30ed4 作者: Nicolas Ballas

Fix opt

上级 b14e9943
......@@ -1277,8 +1277,8 @@ def local_conv3d_fft(node):
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
rval = conv3d_fft(x, f)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2])
return rval
return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2])]
gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
......@@ -1293,6 +1293,8 @@ def local_convgrad3d_fft(node):
return False
if (isinstance(node.op, GpuConvGrad3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, ic) to (ic, b, 0, 1, t)
x = node.inputs[0]
x = x.dimshuffle(4, 0, 1, 2, 3)
......@@ -1301,8 +1303,8 @@ def local_convgrad3d_fft(node):
f = f.dimshuffle(4, 0, 1, 2, 3)
rval = conv3d_fft(x, f)
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0))
return rval
return [gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0))]
gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft)
......@@ -1327,8 +1329,7 @@ def local_convtransp3d_fft(node):
f = f.dimshuffle(0, 4, 1, 2, 3)
rval = conv3d_fft(f, x, border_mode='full')
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))
return rval
return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))]
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论