提交 ea33f222 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Allow other values than 'cpu' and 'gpu?' for the device.

上级 dca4bc5b
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import subprocess import subprocess
from theano.configparser import ( from theano.configparser import (
AddConfigVar, BoolParam, ConfigParam, EnumStr, IntParam, AddConfigVar, BoolParam, ConfigParam, DeviceParam, EnumStr, IntParam,
TheanoConfigParser) TheanoConfigParser)
from theano.misc.cpucount import cpuCount from theano.misc.cpucount import cpuCount
from theano.misc.windows import call_subprocess_Popen from theano.misc.windows import call_subprocess_Popen
...@@ -49,12 +49,7 @@ AddConfigVar('device', ...@@ -49,12 +49,7 @@ AddConfigVar('device',
"to move computation to it and to put shared variable of float32 " "to move computation to it and to put shared variable of float32 "
"on it. Do not use upper case letters, only lower case even if " "on it. Do not use upper case letters, only lower case even if "
"NVIDIA use capital letters."), "NVIDIA use capital letters."),
EnumStr('cpu', 'gpu', DeviceParam('cpu', allow_override=False),
'gpu0', 'gpu1', 'gpu2', 'gpu3',
'gpu4', 'gpu5', 'gpu6', 'gpu7',
'gpu8', 'gpu9', 'gpu10', 'gpu11',
'gpu12', 'gpu13', 'gpu14', 'gpu15',
allow_override=False),
in_c_key=False, in_c_key=False,
) )
......
...@@ -314,6 +314,26 @@ class EnumStr(ConfigParam): ...@@ -314,6 +314,26 @@ class EnumStr(ConfigParam):
return '%s (%s) ' % (self.fullname, self.all) return '%s (%s) ' % (self.fullname, self.all)
class DeviceParam(ConfigParam):
def __init__(self, default, *options, **kwargs):
self.default = default
def filter(val):
if val.startswith('cpu') or val.startswith('gpu') \
or val.startswith('opencl') or val.startswith('cuda'):
return val
else:
raise ValueError(('Invalid value ("%s") for configuration '
'variable "%s". Valid options start with '
'one of "cpu", "gpu", "opencl", "cuda"'
% (val, self.fullname)))
over = kwargs.get("allow_override", True)
super(DeviceParam, self).__init__(default, filter, over)
def __str__(self):
return '%s (cpu, gpu*, opencl*, cuda*) ' % (self.fullname,)
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True): def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype self.mytype = mytype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论