提交 a9454113 authored 作者: Frederic's avatar Frederic

small code clean up (and remove potential bug that currently don't happen)

上级 7bab3eac
......@@ -43,6 +43,7 @@ from theano.sandbox.cuda.var import CudaNdarrayConstant
from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.blas import _is_real_vector, _is_real_matrix
from theano.tensor import nlinalg
from theano.tensor.nnet.Conv3D import Conv3D
#optdb.print_summary() # shows what is currently registered
......@@ -1236,17 +1237,18 @@ def local_conv_fft_full(node):
gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
from theano.tensor.nnet.Conv3D import Conv3D
@local_optimizer([Conv3D])
def local_conv3d_fft(node):
if not isinstance(node.op, Conv3D):
return
try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError:
return False
if (isinstance(node.op, Conv3D) and
(stride_x, stride_y, stride_z) == (1, 1, 1)):
if (stride_x, stride_y, stride_z) == (1, 1, 1):
# we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
......@@ -1256,7 +1258,7 @@ def local_conv3d_fft(node):
f = node.inputs[1]
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
# filter flip
f = f[:,:,::-1,::-1,::-1]
f = f[:, :, ::-1, ::-1, ::-1]
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论