提交 60d5f91d authored 作者: abergeron's avatar abergeron

Merge pull request #1896 from abergeron/cuda_fftconv

Don't import fftconv.py unless it's necessary to avoid a PyCUDA warning.
......@@ -40,7 +40,6 @@ from theano.sandbox.cuda.elemwise import SupportCodeError
from theano.scalar.basic_scipy import Erfinv
from theano.sandbox.cuda.elemwise import erfinv_gpu
from theano.sandbox.cuda.var import CudaNdarrayConstant
from theano.sandbox.cuda.fftconv import conv2d_fft
from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.blas import _is_real_vector, _is_real_matrix
linalg = None
......@@ -1122,6 +1121,8 @@ def local_gpu_conv(node):
@local_optimizer([GpuConv])
def local_conv_fft_valid(node):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
if (isinstance(node.op, GpuConv) and
node.op.border_mode == 'valid' and
node.op.subsample == (1, 1)):
......@@ -1130,6 +1131,8 @@ def local_conv_fft_valid(node):
@local_optimizer([GpuConv])
def local_conv_fft_full(node):
# import locally to avoid pycuda warnings
from theano.sandbox.cuda.fftconv import conv2d_fft
if (isinstance(node.op, GpuConv) and
node.op.border_mode == 'full' and
node.op.subsample == (1, 1)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论