提交 cbd84195 authored 作者: Frederic's avatar Frederic

Allow to disable the GPU when device=cpu and force_device=True

上级 7f860ba0
...@@ -114,19 +114,6 @@ def try_import(): ...@@ -114,19 +114,6 @@ def try_import():
return True return True
# Add the theano cache directory's cuda_ndarray subdirectory to the
# list of places that are hard-coded into compiled modules' runtime
# library search list. This works in conjunction with
# nvcc_compiler.NVCC_compiler.compile_str which adds this folder during
# compilation with -L and also adds -lcuda_ndarray when compiling
# modules.
nvcc_compiler.add_standard_rpath(cuda_ndarray_loc)
compile_cuda_ndarray = True
if not compile_cuda_ndarray:
compile_cuda_ndarray = not try_import()
if not nvcc_compiler.is_nvcc_available() or not theano.config.cxx: if not nvcc_compiler.is_nvcc_available() or not theano.config.cxx:
# It can happen that the file cuda_ndarray.so is already compiled # It can happen that the file cuda_ndarray.so is already compiled
# but nvcc is not available. In that case we need to disable the CUDA # but nvcc is not available. In that case we need to disable the CUDA
...@@ -134,6 +121,20 @@ if not nvcc_compiler.is_nvcc_available() or not theano.config.cxx: ...@@ -134,6 +121,20 @@ if not nvcc_compiler.is_nvcc_available() or not theano.config.cxx:
# use already compiled GPU op and not the others. # use already compiled GPU op and not the others.
# Also, if cxx is not available, we need to disable all GPU code. # Also, if cxx is not available, we need to disable all GPU code.
set_cuda_disabled() set_cuda_disabled()
elif not config.device.startswith('gpu') and config.force_device:
# We where asked to NEVER use the GPU
set_cuda_disabled()
compile_cuda_ndarray = False
else:
# Add the theano cache directory's cuda_ndarray subdirectory to the
# list of places that are hard-coded into compiled modules' runtime
# library search list. This works in conjunction with
# nvcc_compiler.NVCC_compiler.compile_str which adds this folder during
# compilation with -L and also adds -lcuda_ndarray when compiling
# modules.
nvcc_compiler.add_standard_rpath(cuda_ndarray_loc)
compile_cuda_ndarray = not try_import()
if compile_cuda_ndarray and cuda_available: if compile_cuda_ndarray and cuda_available:
get_lock() get_lock()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论