提交 14f0a2a0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Really only import if it's needed.

上级 1df8f73e
...@@ -1121,21 +1121,21 @@ def local_gpu_conv(node): ...@@ -1121,21 +1121,21 @@ def local_gpu_conv(node):
@local_optimizer([GpuConv]) @local_optimizer([GpuConv])
def local_conv_fft_valid(node): def local_conv_fft_valid(node):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
if (isinstance(node.op, GpuConv) and if (isinstance(node.op, GpuConv) and
node.op.border_mode == 'valid' and node.op.border_mode == 'valid' and
node.op.subsample == (1, 1)): node.op.subsample == (1, 1)):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
return [conv2d_fft(node.inputs[0], node.inputs[1])] return [conv2d_fft(node.inputs[0], node.inputs[1])]
@local_optimizer([GpuConv]) @local_optimizer([GpuConv])
def local_conv_fft_full(node): def local_conv_fft_full(node):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
if (isinstance(node.op, GpuConv) and if (isinstance(node.op, GpuConv) and
node.op.border_mode == 'full' and node.op.border_mode == 'full' and
node.op.subsample == (1, 1)): node.op.subsample == (1, 1)):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
return [conv2d_fft(node.inputs[0], node.inputs[1], border_mode='full')] return [conv2d_fft(node.inputs[0], node.inputs[1], border_mode='full')]
gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid) gpu_optimizer.register("conv_fft_valid", local_conv_fft_valid)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论