提交 a9454113 authored 作者: Frederic's avatar Frederic

small code clean up (and remove potential bug that currently don't happen)

上级 7bab3eac
...@@ -43,6 +43,7 @@ from theano.sandbox.cuda.var import CudaNdarrayConstant ...@@ -43,6 +43,7 @@ from theano.sandbox.cuda.var import CudaNdarrayConstant
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.blas import _is_real_vector, _is_real_matrix from theano.tensor.blas import _is_real_vector, _is_real_matrix
from theano.tensor import nlinalg from theano.tensor import nlinalg
from theano.tensor.nnet.Conv3D import Conv3D
#optdb.print_summary() # shows what is currently registered #optdb.print_summary() # shows what is currently registered
...@@ -1236,17 +1237,18 @@ def local_conv_fft_full(node): ...@@ -1236,17 +1237,18 @@ def local_conv_fft_full(node):
gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid) gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
gpu_optimizer.register("conv_fft_full", local_conv_fft_full) gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
from theano.tensor.nnet.Conv3D import Conv3D
@local_optimizer([Conv3D]) @local_optimizer([Conv3D])
def local_conv3d_fft(node): def local_conv3d_fft(node):
if not isinstance(node.op, Conv3D):
return
try: try:
stride_x = tensor.get_scalar_constant_value(node.inputs[3][0]) stride_x = tensor.get_scalar_constant_value(node.inputs[3][0])
stride_y = tensor.get_scalar_constant_value(node.inputs[3][1]) stride_y = tensor.get_scalar_constant_value(node.inputs[3][1])
stride_z = tensor.get_scalar_constant_value(node.inputs[3][2]) stride_z = tensor.get_scalar_constant_value(node.inputs[3][2])
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
return False return False
if (isinstance(node.op, Conv3D) and if (stride_x, stride_y, stride_z) == (1, 1, 1):
(stride_x, stride_y, stride_z) == (1, 1, 1)):
# we import conv3d_fft locally to avoid pycuda warnings # we import conv3d_fft locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv3d_fft from theano.sandbox.cuda.fftconv import conv3d_fft
# Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t) # Shuffle inputs signal from (b, 0, 1, t, c) to (b, c, 0, 1, t)
...@@ -1256,7 +1258,7 @@ def local_conv3d_fft(node): ...@@ -1256,7 +1258,7 @@ def local_conv3d_fft(node):
f = node.inputs[1] f = node.inputs[1]
f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3)) f = gpu_from_host(f.dimshuffle(0, 4, 1, 2, 3))
# filter flip # filter flip
f = f[:,:,::-1,::-1,::-1] f = f[:, :, ::-1, ::-1, ::-1]
rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True) rval = conv3d_fft(x, f, border_mode='valid', pad_last_dim=True)
# Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c) # Shuffle from (oc, c, 0, 1, t) to (oc, 0, 1, t, c)
return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]] return [rval.dimshuffle(0, 2, 3, 4, 1) + node.inputs[2]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论