提交 8cce5c55 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Always pass EnumStr options as sequence

上级 d72eed70
......@@ -85,22 +85,39 @@ def test_config_param_apply_and_validation():
cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool():
    """BoolParam should accept common truthy/falsy spellings and reject anything else.

    Checks that both outcomes (True/False) are produced from their string/int/bool
    spellings, that the applied value passes validation, and that an unknown
    string raises ``ValueError``.
    """
    valids = {
        True: ["1", 1, True, "true", "True"],
        False: ["0", 0, False, "false", "False"],
    }
    param = configparser.BoolParam(None)
    assert isinstance(param, configparser.ConfigParam)
    assert param.default is None
    for outcome, inputs in valids.items():
        # `value` instead of `input` — avoid shadowing the builtin.
        for value in inputs:
            applied = param.apply(value)
            assert applied == outcome
            assert param.validate(applied) is not False
    with pytest.raises(ValueError, match="Invalid value"):
        param.apply("notabool")
class TestConfigTypes:
    """Unit tests for the concrete ``ConfigParam`` subclasses in ``configparser``."""

    def test_bool(self):
        """BoolParam accepts common truthy/falsy spellings and rejects others."""
        valids = {
            True: ["1", 1, True, "true", "True"],
            False: ["0", 0, False, "false", "False"],
        }
        param = configparser.BoolParam(None)
        assert isinstance(param, configparser.ConfigParam)
        assert param.default is None
        for outcome, inputs in valids.items():
            # `value` instead of `input` — avoid shadowing the builtin.
            for value in inputs:
                applied = param.apply(value)
                assert applied == outcome
                assert param.validate(applied) is not False
        with pytest.raises(ValueError, match="Invalid value"):
            param.apply("notabool")

    def test_enumstr(self):
        """EnumStr collects default + options into `.all` and rejects non-str options."""
        cp = configparser.EnumStr("blue", ["red", "green", "yellow"])
        # Default plus the three options.
        assert len(cp.all) == 4
        with pytest.raises(ValueError, match=r"Invalid value \('foo'\)"):
            cp.apply("foo")
        with pytest.raises(ValueError, match="Non-str value"):
            configparser.EnumStr(default="red", options=["red", 12, "yellow"])

    def test_deviceparam(self):
        """DeviceParam validates device strings and renders its repr."""
        cp = configparser.DeviceParam("cpu", mutable=False)
        assert cp.default == "cpu"
        assert cp._apply("cuda123") == "cuda123"
        with pytest.raises(ValueError, match="old GPU back-end"):
            cp._apply("gpu123")
        with pytest.raises(ValueError, match="Invalid value"):
            cp._apply("notadevice")
        assert str(cp) == "None (cpu, opencl*, cuda*) "
......@@ -39,11 +39,7 @@ AddConfigVar(
"Default floating-point precision for python casts.\n"
"\n"
"Note: float16 support is experimental, use at your own risk.",
EnumStr(
"float64",
"float32",
"float16",
),
EnumStr("float64", ["float32", "float16"]),
# TODO: see gh-4466 for how to remove it.
in_c_key=True,
)
......@@ -53,7 +49,7 @@ AddConfigVar(
"Do an action when a tensor variable with float64 dtype is"
" created. They can't be run on the GPU with the current(old)"
" gpu back-end and are slow with gamer GPUs.",
EnumStr("ignore", "warn", "raise", "pdb"),
EnumStr("ignore", ["warn", "raise", "pdb"]),
in_c_key=False,
)
......@@ -70,7 +66,7 @@ AddConfigVar(
"Rules for implicit type casting",
EnumStr(
"custom",
"numpy+floatX",
["numpy+floatX"],
# The 'numpy' policy was originally planned to provide a
# smooth transition from numpy. It was meant to behave the
# same as numpy+floatX, but keeping float64 when numpy
......@@ -89,7 +85,7 @@ AddConfigVar(
AddConfigVar(
"int_division",
"What to do when one computes x / y, where both x and y are of " "integer types",
EnumStr("int", "raise", "floatX"),
EnumStr("int", ["raise", "floatX"]),
in_c_key=False,
)
......@@ -101,7 +97,7 @@ AddConfigVar(
"non-deterministic implementaion, e.g. when we do not have a GPU "
"implementation that is deterministic. Also see "
"the dnn.conv.algo* flags to cover more cases.",
EnumStr("default", "more"),
EnumStr("default", ["more"]),
in_c_key=False,
)
......@@ -218,7 +214,7 @@ AddConfigVar(
CPU overhead when waiting for GPU. One user found that it
speeds up his other processes that was doing data augmentation.
""",
EnumStr("default", "multi", "single"),
EnumStr("default", ["multi", "single"]),
)
AddConfigVar(
......@@ -325,7 +321,7 @@ SUPPORTED_DNN_CONV_PRECISION = (
AddConfigVar(
"dnn.conv.algo_fwd",
"Default implementation to use for cuDNN forward convolution.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_FWD),
EnumStr("small", SUPPORTED_DNN_CONV_ALGO_FWD),
in_c_key=False,
)
......@@ -333,7 +329,7 @@ AddConfigVar(
"dnn.conv.algo_bwd_data",
"Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the inputs.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_DATA),
EnumStr("none", SUPPORTED_DNN_CONV_ALGO_BWD_DATA),
in_c_key=False,
)
......@@ -342,7 +338,7 @@ AddConfigVar(
"Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the "
"filters.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_FILTER),
EnumStr("none", SUPPORTED_DNN_CONV_ALGO_BWD_FILTER),
in_c_key=False,
)
......@@ -351,7 +347,7 @@ AddConfigVar(
"Default data precision to use for the computation in cuDNN "
"convolutions (defaults to the same dtype as the inputs of the "
"convolutions, or float32 if inputs are float16).",
EnumStr(*SUPPORTED_DNN_CONV_PRECISION),
EnumStr("as_input_f32", SUPPORTED_DNN_CONV_PRECISION),
in_c_key=False,
)
......@@ -434,7 +430,7 @@ AddConfigVar(
" If True and cuDNN can not be used, raise an error."
" If False, disable cudnn even if present."
" If no_check, assume present and the version between header and library match (so less compilation at context init)",
EnumStr("auto", "True", "False", "no_check"),
EnumStr("auto", ["True", "False", "no_check"]),
in_c_key=False,
)
......@@ -458,7 +454,7 @@ AddConfigVar(
AddConfigVar(
"assert_no_cpu_op",
"Raise an error/warning if there is a CPU op in the computational graph.",
EnumStr("ignore", "warn", "raise", "pdb", mutable=True),
EnumStr("ignore", ["warn", "raise", "pdb"], mutable=True),
in_c_key=False,
)
......@@ -583,7 +579,7 @@ if rc == 0 and config.cxx != "":
AddConfigVar(
"linker",
"Default linker used if the theano flags mode is Mode",
EnumStr("cvm", "c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"),
EnumStr("cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]),
in_c_key=False,
)
else:
......@@ -592,7 +588,7 @@ else:
AddConfigVar(
"linker",
"Default linker used if the theano flags mode is Mode",
EnumStr("vm", "py", "vm_nogc"),
EnumStr("vm", ["py", "vm_nogc"]),
in_c_key=False,
)
if type(config).cxx.is_default:
......@@ -623,7 +619,7 @@ AddConfigVar(
"optimizer",
"Default optimizer. If not None, will use this optimizer with the Mode",
EnumStr(
"o4", "o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"
"o4", ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"]
),
in_c_key=False,
)
......@@ -641,7 +637,7 @@ AddConfigVar(
"What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."
),
EnumStr("warn", "raise", "pdb", "ignore"),
EnumStr("warn", ["raise", "pdb", "ignore"]),
in_c_key=False,
)
......@@ -656,7 +652,7 @@ AddConfigVar(
"on_unused_input",
"What to do if a variable in the 'inputs' list of "
" theano.function() is not used in the graph.",
EnumStr("raise", "warn", "ignore"),
EnumStr("raise", ["warn", "ignore"]),
in_c_key=False,
)
......@@ -756,7 +752,7 @@ AddConfigVar(
"by the following flags: seterr_divide, seterr_over, "
"seterr_under and seterr_invalid.",
),
EnumStr("ignore", "warn", "raise", "call", "print", "log", "None", mutable=False),
EnumStr("ignore", ["warn", "raise", "call", "print", "log", "None"], mutable=False),
in_c_key=False,
)
......@@ -766,7 +762,7 @@ AddConfigVar(
"Sets numpy's behavior for division by zero, see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False,
)
......@@ -777,7 +773,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False,
)
......@@ -788,7 +784,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False,
)
......@@ -799,7 +795,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False,
)
......@@ -818,25 +814,27 @@ AddConfigVar(
),
EnumStr(
"0.9",
"None",
"all",
"0.3",
"0.4",
"0.4.1",
"0.5",
"0.6",
"0.7",
"0.8",
"0.8.1",
"0.8.2",
"0.9",
"0.10",
"1.0",
"1.0.1",
"1.0.2",
"1.0.3",
"1.0.4",
"1.0.5",
[
"None",
"all",
"0.3",
"0.4",
"0.4.1",
"0.5",
"0.6",
"0.7",
"0.8",
"0.8.1",
"0.8.2",
"0.9",
"0.10",
"1.0",
"1.0.1",
"1.0.2",
"1.0.3",
"1.0.4",
"1.0.5",
],
mutable=False,
),
in_c_key=False,
......@@ -1005,7 +1003,7 @@ AddConfigVar(
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."
),
EnumStr("off", "ignore", "warn", "raise", "pdb"),
EnumStr("off", ["ignore", "warn", "raise", "pdb"]),
in_c_key=False,
)
......@@ -1030,7 +1028,7 @@ AddConfigVar(
" Same as compute_test_value, but is used"
" during Theano optimization"
),
EnumStr("off", "ignore", "warn", "raise", "pdb"),
EnumStr("off", ["ignore", "warn", "raise", "pdb"]),
in_c_key=False,
)
......@@ -1068,7 +1066,7 @@ AddConfigVar(
A. Elemwise{add_no_inplace}
B. log_likelihood_v_given_h
C. log_likelihood_h""",
EnumStr("low", "high"),
EnumStr("low", ["high"]),
in_c_key=False,
)
......@@ -1190,7 +1188,7 @@ AddConfigVar(
AddConfigVar(
"NanGuardMode.action",
"What NanGuardMode does when it finds a problem",
EnumStr("raise", "warn", "pdb"),
EnumStr("raise", ["warn", "pdb"]),
in_c_key=False,
)
......@@ -1876,7 +1874,7 @@ AddConfigVar(
AddConfigVar(
"on_shape_error",
"warn: print a warning and use the default" " value. raise: raise an error",
theano.configparser.EnumStr("warn", "raise"),
theano.configparser.EnumStr("warn", ["raise"]),
in_c_key=False,
)
......@@ -1943,7 +1941,7 @@ AddConfigVar(
"The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak"
"memory usage, triyng both could give you a small gain.",
EnumStr("regular", "fast"),
EnumStr("regular", ["fast"]),
in_c_key=False,
)
......@@ -1957,7 +1955,7 @@ AddConfigVar(
"stack trace is inserted that indicates which optimization inserted"
"the variable that had an empty stack trace."
"raise: raises an exception if a stack trace is missing",
EnumStr("off", "log", "warn", "raise"),
EnumStr("off", ["log", "warn", "raise"]),
in_c_key=False,
)
......
......@@ -425,23 +425,37 @@ class ConfigParam:
class EnumStr(ConfigParam):
    """A str-valued configuration parameter restricted to a fixed set of options.

    NOTE(review): the scraped diff left two conflicting ``__init__`` signatures
    and two ``super().__init__`` calls merged together; this is the reconstructed
    post-commit version (options passed as a sequence).
    """

    def __init__(
        self, default: str, options: typing.Sequence[str], validate=None, mutable=True
    ):
        """Creates a str-based parameter that takes a predefined set of options.

        Parameters
        ----------
        default : str
            The default setting.
        options : sequence
            Further str values that the parameter may take.
            May, but does not need to include the default.
        validate : callable
            See `ConfigParam`.
        mutable : bool
            See `ConfigParam`.

        Raises
        ------
        ValueError
            If any option (or the default) is not a ``str``.
        """
        # The set of all permitted values; the default is always a member.
        self.all = {default, *options}
        # All options should be strings
        for val in self.all:
            if not isinstance(val, str):
                raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.")
        super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)

    def _apply(self, val):
        """Return *val* unchanged if it is a permitted option, else raise ValueError."""
        if val in self.all:
            return val
        else:
            raise ValueError(
                f"Invalid value ('{val}') for configuration variable '{self.fullname}'. "
                f"Valid options are {self.all}"
            )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论