提交 d3da0a2d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Move DeviceParam to configdefaults.py

上级 86e48dc9
......@@ -2,9 +2,8 @@ import os
import logging
import subprocess
from theano.configparser import (
AddConfigVar, BoolParam, ConfigParam, DeviceParam, EnumStr, IntParam,
StrParam, TheanoConfigParser)
from theano.configparser import (AddConfigVar, BoolParam, ConfigParam, EnumStr,
IntParam, StrParam, TheanoConfigParser)
from theano.misc.cpucount import cpuCount
from theano.misc.windows import call_subprocess_Popen
......@@ -44,6 +43,25 @@ AddConfigVar('int_division',
# gpu means let the driver select the gpu. Needed in case of gpu in
# exclusive mode.
# gpuX mean use the gpu number X.
class DeviceParam(ConfigParam):
def __init__(self, default, *options, **kwargs):
self.default = default
def filter(val):
if val.startswith('cpu') or val.startswith('gpu') \
or val.startswith('opencl') or val.startswith('cuda'):
return val
else:
raise ValueError(('Invalid value ("%s") for configuration '
'variable "%s". Valid options start with '
'one of "cpu", "gpu", "opencl", "cuda"'
% (val, self.fullname)))
over = kwargs.get("allow_override", True)
super(DeviceParam, self).__init__(default, filter, over)
def __str__(self):
return '%s (cpu, gpu*, opencl*, cuda*) ' % (self.fullname,)
AddConfigVar('device',
("Default device for computations. If gpu*, change the default to try "
"to move computation to it and to put shared variable of float32 "
......
......@@ -314,26 +314,6 @@ class EnumStr(ConfigParam):
return '%s (%s) ' % (self.fullname, self.all)
class DeviceParam(ConfigParam):
def __init__(self, default, *options, **kwargs):
self.default = default
def filter(val):
if val.startswith('cpu') or val.startswith('gpu') \
or val.startswith('opencl') or val.startswith('cuda'):
return val
else:
raise ValueError(('Invalid value ("%s") for configuration '
'variable "%s". Valid options start with '
'one of "cpu", "gpu", "opencl", "cuda"'
% (val, self.fullname)))
over = kwargs.get("allow_override", True)
super(DeviceParam, self).__init__(default, filter, over)
def __str__(self):
return '%s (cpu, gpu*, opencl*, cuda*) ' % (self.fullname,)
class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论