提交 3e57181e authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Move ContextParam class to theano.configparser

上级 dde33dc5
...@@ -17,6 +17,7 @@ from theano.configparser import ( ...@@ -17,6 +17,7 @@ from theano.configparser import (
AddConfigVar, AddConfigVar,
BoolParam, BoolParam,
ConfigParam, ConfigParam,
ContextsParam,
DeviceParam, DeviceParam,
EnumStr, EnumStr,
FloatParam, FloatParam,
...@@ -100,17 +101,15 @@ AddConfigVar( ...@@ -100,17 +101,15 @@ AddConfigVar(
in_c_key=False, in_c_key=False,
) )
# gpu means let the driver select the gpu. Needed in case of gpu in
# exclusive mode.
# gpuX mean use the gpu number X.
AddConfigVar( AddConfigVar(
"device", "device",
( (
"Default device for computations. If cuda* or opencl*, change the" "Default device for computations. If cuda* or opencl*, change the"
"default to try to move computation to the GPU. Do not use upper case" "default to try to move computation to the GPU. Do not use upper case"
"letters, only lower case even if NVIDIA uses capital letters." "letters, only lower case even if NVIDIA uses capital letters. "
"'gpu' means let the driver select the gpu (needed for gpu in exclusive mode). "
"'gpuX' mean use the gpu number X."
), ),
DeviceParam("cpu", mutable=False), DeviceParam("cpu", mutable=False),
in_c_key=False, in_c_key=False,
...@@ -152,26 +151,6 @@ AddConfigVar( ...@@ -152,26 +151,6 @@ AddConfigVar(
) )
class ContextsParam(ConfigParam):
def __init__(self):
def filter(val):
if val == "":
return val
for v in val.split(";"):
s = v.split("->")
if len(s) != 2:
raise ValueError(f"Malformed context map: {v}")
if (
s[0] == "cpu"
or s[0].startswith("cuda")
or s[0].startswith("opencl")
):
raise ValueError(f"Cannot use {s[0]} as context name")
return val
ConfigParam.__init__(self, "", apply=filter, mutable=False)
AddConfigVar( AddConfigVar(
"contexts", "contexts",
""" """
......
...@@ -528,3 +528,19 @@ class DeviceParam(ConfigParam): ...@@ -528,3 +528,19 @@ class DeviceParam(ConfigParam):
def __str__(self): def __str__(self):
return f"{self.fullname} ({self.default}, opencl*, cuda*) " return f"{self.fullname} ({self.default}, opencl*, cuda*) "
class ContextsParam(ConfigParam):
def __init__(self):
super().__init__("", apply=self._apply, mutable=False)
def _apply(self, val):
if val == "":
return val
for v in val.split(";"):
s = v.split("->")
if len(s) != 2:
raise ValueError(f"Malformed context map: {v}")
if s[0] == "cpu" or s[0].startswith("cuda") or s[0].startswith("opencl"):
raise ValueError(f"Cannot use {s[0]} as context name")
return val
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论