提交 aa892a7e authored 作者: abergeron's avatar abergeron

Merge pull request #2279 from nouiz/import_fix

Fix crash at import
...@@ -17,7 +17,7 @@ from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB, ...@@ -17,7 +17,7 @@ from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer, toolbox) Optimizer, toolbox)
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.sandbox.cuda.basic_ops import ( from theano.sandbox.cuda.basic_ops import (
device_properties, gpu_eye, gpu_contiguous, gpu_eye, gpu_contiguous,
gpu_from_host, host_from_gpu, GpuFromHost, HostFromGpu, gpu_from_host, host_from_gpu, GpuFromHost, HostFromGpu,
GpuElemwise, GpuDimShuffle, GpuReshape, GpuCAReduce, GpuFlatten, GpuElemwise, GpuDimShuffle, GpuReshape, GpuCAReduce, GpuFlatten,
GpuSubtensor, GpuAdvancedSubtensor1, GpuSubtensor, GpuAdvancedSubtensor1,
...@@ -47,6 +47,13 @@ from theano.tensor.blas import _is_real_vector, _is_real_matrix ...@@ -47,6 +47,13 @@ from theano.tensor.blas import _is_real_vector, _is_real_matrix
from theano.tensor import nlinalg from theano.tensor import nlinalg
from theano.tensor.nnet.Conv3D import Conv3D from theano.tensor.nnet.Conv3D import Conv3D
try:
# We need to be able to import this file even if cuda isn't avail.
from theano.sandbox.cuda import device_properties
except ImportError:
pass
#optdb.print_summary() # shows what is currently registered #optdb.print_summary() # shows what is currently registered
#ignore_newtrees is to speed the optimization as this is the pattern #ignore_newtrees is to speed the optimization as this is the pattern
...@@ -1337,10 +1344,12 @@ conv_groupopt.register('conv_fft_full', local_conv_fft_full, 10, ...@@ -1337,10 +1344,12 @@ conv_groupopt.register('conv_fft_full', local_conv_fft_full, 10,
# cuDNN is the second, but only registered if cuDNN is available. # cuDNN is the second, but only registered if cuDNN is available.
# It can be disabled by excluding 'conv_dnn' or 'cudnn'. # It can be disabled by excluding 'conv_dnn' or 'cudnn'.
from . import dnn from . import dnn
if dnn.dnn_available(): # We can't check at import if dnn is available, so we must always
conv_groupopt.register('local_conv_dnn', dnn.local_conv_dnn, 20, # register it. This do not cause problem as if it is not avail, the
'conv_dnn', # opt will do nothing.
'fast_compile', 'fast_run', 'cudnn') conv_groupopt.register('local_conv_dnn', dnn.local_conv_dnn, 20,
'conv_dnn',
'fast_compile', 'fast_run', 'cudnn')
# The GEMM-based convolution comes last to catch all remaining cases. # The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'. # It can be disabled by excluding 'conv_gemm'.
conv_groupopt.register('local_conv_gemm', local_conv_gemm, 30, conv_groupopt.register('local_conv_gemm', local_conv_gemm, 30,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论