提交 2051ae99 作者: Nicolas Ballas

Add kernel flip to follow the behavior of Conv3D/ConvGrad3D

上级 bab9eacb
...@@ -1275,6 +1275,8 @@ def local_conv3d_fft(node): ...@@ -1275,6 +1275,8 @@ def local_conv3d_fft(node):
# Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t) # Shuffle filters from (oc, 0, 1, t, ic) to (oc, ic, 0, 1, t)
f = node.inputs[1] f = node.inputs[1]
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3)) f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
# filter flip
f = f[:,:,::-1,::-1,::-1]
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True) rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c) # Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]] return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
...@@ -1301,6 +1303,8 @@ def local_convgrad3d_fft(node): ...@@ -1301,6 +1303,8 @@ def local_convgrad3d_fft(node):
# Shuffle dCdH from (b, 0, 1, t, oc) to (oc, b, 0, 1, t) # Shuffle dCdH from (b, 0, 1, t, oc) to (oc, b, 0, 1, t)
f = node.inputs[3] f = node.inputs[3]
f = f.dimshuffle(4, 0, 1, 2, 3) f = f.dimshuffle(4, 0, 1, 2, 3)
# filter flip
f = f[:,:,::-1,::-1,::-1]
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True) rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic) # Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
return [rval.dimshuffle(1, 2, 3, 4, 0)] return [rval.dimshuffle(1, 2, 3, 4, 0)]
...@@ -1329,7 +1333,7 @@ def local_convtransp3d_fft(node): ...@@ -1329,7 +1333,7 @@ def local_convtransp3d_fft(node):
f = f.dimshuffle(0, 4, 1, 2, 3) f = f.dimshuffle(0, 4, 1, 2, 3)
rval = conv3d_fft(f, x, border_mode='full', pad_last_dim=True) rval = conv3d_fft(f, x, border_mode='full', pad_last_dim=True)
# Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic) # Shuffle from (ic, b, 0, 1, t) to (b, 0, 1, t, ic)
return [rval.dimshuffle(0, 2, 3, 4, 1)] return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[1]]
gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft) gpu_optimizer.register("convtransp3d_fft", local_convtransp3d_fft)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论