提交 bc9fbfaf authored 作者: Frederic Bastien's avatar Frederic Bastien

added theano flags force_device that have precedence over theano flags device…

added theano flags force_device that have precedence over theano flags device when its value is not cpu. If can't use the gived device, we raise an error.
上级 007ad0f6
......@@ -75,7 +75,7 @@ import scalar
import gradient
import gof
if config.device.startswith('gpu'):
if config.device.startswith('gpu') or config.force_device.startswith('gpu'):
import theano.sandbox.cuda
## import scalar_opt
......
......@@ -15,6 +15,11 @@ AddConfigVar('device',
EnumStr('cpu', 'gpu',*['gpu%i'%i for i in range(4)])
)
AddConfigVar('force_device',
"Have precedence over device if not equal cpu.",
EnumStr('cpu', 'gpu',*['gpu%i'%i for i in range(4)])
)
AddConfigVar('mode',
"Default compilation mode",
EnumStr('Mode', 'ProfileMode', 'DebugMode', 'FAST_RUN', 'FAST_COMPILE', 'PROFILE_MODE', 'DEBUG_MODE'))
......
......@@ -120,7 +120,7 @@ if cuda_available:
import cuda_ndarray
def use(device):
def use(device, force=False):
global cuda_enabled
if device == 'gpu':
pass
......@@ -149,6 +149,10 @@ def use(device):
except RuntimeError, e:
_logger.error("ERROR: Not using GPU. Initialisation of device %i failed. %s" %(device, e))
cuda_enabled = False
if force:
e.args+=("You asked to force this device and it failed. No fallback to the cpu or other gpu device.",)
raise
elif use.device_number != device:
_logger.warning("WARNING: ignoring call to use(%s), GPU number %i is already in use." %(str(device), use.device_number))
optdb.add_tags('gpu',
......@@ -169,6 +173,16 @@ def handle_shared_float32(tf):
else:
raise NotImplementedError('removing our handler')
if cuda_available and config.device.startswith('gpu'):
if cuda_available and config.force_device.startswith('gpu'):
use(config.force_device, True)
elif cuda_available and config.device.startswith('gpu'):
use(config.device)
if config.force_device.startswith('gpu'):
try:
#in case the device if just gpu, we check that the driver init it correctly.
cuda_ndarray.cuda_ndarray.CudaNdarray.zeros((5,5))
except (Exception, NameError), e:#NameError when no gpu present as cuda_ndarray is not loaded.
e.args+=("ERROR: GPU did not work and we told to don't use the cpu. ",)
raise
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论