提交 21f30ed4 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix opt

上级 b14e9943
...@@ -1277,8 +1277,8 @@ def local_conv3d_fft(node): ...@@ -1277,8 +1277,8 @@ def local_conv3d_fft(node):
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3)) f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
rval = conv3d_fft(x, f) rval = conv3d_fft(x, f)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c) # Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]) return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2])]
return rval
gpu_optimizer.register("conv3d_fft", local_conv3d_fft) gpu_optimizer.register("conv3d_fft", local_conv3d_fft)
...@@ -1293,6 +1293,8 @@ def local_convgrad3d_fft(node): ...@@ -1293,6 +1293,8 @@ def local_convgrad3d_fft(node):
return False return False
if (isinstance(node.op, GpuConvGrad3D) and if (isinstance(node.op, GpuConvGrad3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)): (stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, ic) to (ic, b, 0, 1, t) # Shuffle inputs signal from (b, 0, 1, t, ic) to (ic, b, 0, 1, t)
x = node.inputs[0] x = node.inputs[0]
x = x.dimshuffle(4, 0, 1, 2, 3) x = x.dimshuffle(4, 0, 1, 2, 3)
...@@ -1301,8 +1303,8 @@ def local_convgrad3d_fft(node): ...@@ -1301,8 +1303,8 @@ def local_convgrad3d_fft(node):
f = f.dimshuffle(4, 0, 1, 2, 3) f = f.dimshuffle(4, 0, 1, 2, 3)
rval = conv3d_fft(x, f) rval = conv3d_fft(x, f)
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic) # Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0)) return [gpu_from_host(rval.dimshuffle(1, 2, 3, 4, 0))]
return rval
gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft) gpu_optimizer.register("convgrad3d_fft", local_convgrad3d_fft)
...@@ -1327,8 +1329,7 @@ def local_convtransp3d_fft(node): ...@@ -1327,8 +1329,7 @@ def local_convtransp3d_fft(node):
f = f.dimshuffle(0, 4, 1, 2, 3) f = f.dimshuffle(0, 4, 1, 2, 3)
rval = conv3d_fft(f, x, border_mode='full') rval = conv3d_fft(f, x, border_mode='full')
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic) # Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
rval = gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1)) return [gpu_from_host(rval.dimshuffle(0, 2, 3, 4, 1))]
return rval
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft) gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论