提交 b8e3cf3f authored 作者: Frederic's avatar Frederic

Change how we set the pycuda device to be sure we always have the good one initialized by Theano.

上级 32d12f3c
import os
import warnings
import theano
import theano.sandbox.cuda as cuda
cuda_ndarray = cuda.cuda_ndarray.cuda_ndarray
def select_gpu_from_theano():
# Transfer the theano gpu binding to pycuda, for consistency
theano_to_pycuda_device_map = {"cpu": "0",
"gpu0": "0",
"gpu1": "1",
"gpu2": "2",
"gpu3": "3"}
dev = theano_to_pycuda_device_map.get(theano.config.device, "0")
if theano.config.device == 'gpu':
dev = str(cuda.cuda_ndarray.cuda_ndarray.active_device_number())
os.environ["CUDA_DEVICE"] = dev
def set_gpu_from_theano():
"""
This set the GPU used by PyCUDA to the same as the one used by Theano.
"""
#import pdb;pdb.set_trace()
if cuda.use.device_number is None:
cuda.use("gpu",
force=False,
default_to_move_computation_to_gpu=False,
move_shared_float32_to_gpu=False,
enable_cuda=True,
test_driver=True)
select_gpu_from_theano()
assert cuda.use.device_number == cuda_ndarray.active_device_number()
# os.environ["CUDA_DEVICE"] = str(cuda.use.device_number)
set_gpu_from_theano()
pycuda_available = False
try:
import pycuda
import pycuda.autoinit
pycuda_available = True
except ImportError:
# presumably, the user wanted to use pycuda, else they wouldn't have
# imported this module, so issue a warning that the import failed.
import warnings
warnings.warn("PyCUDA import failed in theano.misc.pycuda_init")
if True: # theano.sandbox.cuda.use.device_number is None:
try:
import pycuda
import pycuda.autoinit
pycuda_available = True
except ImportError:
# presumably, the user wanted to use pycuda, else they wouldn't have
# imported this module, so issue a warning that the import failed.
warnings.warn("PyCUDA import failed in theano.misc.pycuda_init")
else:
warnings.warn("theano.misc.pycuda_init must be imported before theano"
" init its GPU")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论