提交 8cce5c55 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Always pass EnumStr options as sequence

上级 d72eed70
...@@ -85,22 +85,39 @@ def test_config_param_apply_and_validation(): ...@@ -85,22 +85,39 @@ def test_config_param_apply_and_validation():
cp.__set__("cls", "THEDEFAULT") cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool(): class TestConfigTypes:
valids = { def test_bool(self):
True: ["1", 1, True, "true", "True"], valids = {
False: ["0", 0, False, "false", "False"], True: ["1", 1, True, "true", "True"],
} False: ["0", 0, False, "false", "False"],
}
param = configparser.BoolParam(None) param = configparser.BoolParam(None)
assert isinstance(param, configparser.ConfigParam)
assert isinstance(param, configparser.ConfigParam) assert param.default is None
assert param.default is None for outcome, inputs in valids.items():
for input in inputs:
for outcome, inputs in valids.items(): applied = param.apply(input)
for input in inputs: assert applied == outcome
applied = param.apply(input) assert param.validate(applied) is not False
assert applied == outcome with pytest.raises(ValueError, match="Invalid value"):
assert param.validate(applied) is not False param.apply("notabool")
pass
with pytest.raises(ValueError, match="Invalid value"):
param.apply("notabool") def test_enumstr(self):
cp = configparser.EnumStr("blue", ["red", "green", "yellow"])
assert len(cp.all) == 4
with pytest.raises(ValueError, match=r"Invalid value \('foo'\)"):
cp.apply("foo")
with pytest.raises(ValueError, match="Non-str value"):
configparser.EnumStr(default="red", options=["red", 12, "yellow"])
pass
def test_deviceparam(self):
cp = configparser.DeviceParam("cpu", mutable=False)
assert cp.default == "cpu"
assert cp._apply("cuda123") == "cuda123"
with pytest.raises(ValueError, match="old GPU back-end"):
cp._apply("gpu123")
with pytest.raises(ValueError, match="Invalid value"):
cp._apply("notadevice")
assert str(cp) == "None (cpu, opencl*, cuda*) "
...@@ -39,11 +39,7 @@ AddConfigVar( ...@@ -39,11 +39,7 @@ AddConfigVar(
"Default floating-point precision for python casts.\n" "Default floating-point precision for python casts.\n"
"\n" "\n"
"Note: float16 support is experimental, use at your own risk.", "Note: float16 support is experimental, use at your own risk.",
EnumStr( EnumStr("float64", ["float32", "float16"]),
"float64",
"float32",
"float16",
),
# TODO: see gh-4466 for how to remove it. # TODO: see gh-4466 for how to remove it.
in_c_key=True, in_c_key=True,
) )
...@@ -53,7 +49,7 @@ AddConfigVar( ...@@ -53,7 +49,7 @@ AddConfigVar(
"Do an action when a tensor variable with float64 dtype is" "Do an action when a tensor variable with float64 dtype is"
" created. They can't be run on the GPU with the current(old)" " created. They can't be run on the GPU with the current(old)"
" gpu back-end and are slow with gamer GPUs.", " gpu back-end and are slow with gamer GPUs.",
EnumStr("ignore", "warn", "raise", "pdb"), EnumStr("ignore", ["warn", "raise", "pdb"]),
in_c_key=False, in_c_key=False,
) )
...@@ -70,7 +66,7 @@ AddConfigVar( ...@@ -70,7 +66,7 @@ AddConfigVar(
"Rules for implicit type casting", "Rules for implicit type casting",
EnumStr( EnumStr(
"custom", "custom",
"numpy+floatX", ["numpy+floatX"],
# The 'numpy' policy was originally planned to provide a # The 'numpy' policy was originally planned to provide a
# smooth transition from numpy. It was meant to behave the # smooth transition from numpy. It was meant to behave the
# same as numpy+floatX, but keeping float64 when numpy # same as numpy+floatX, but keeping float64 when numpy
...@@ -89,7 +85,7 @@ AddConfigVar( ...@@ -89,7 +85,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"int_division", "int_division",
"What to do when one computes x / y, where both x and y are of " "integer types", "What to do when one computes x / y, where both x and y are of " "integer types",
EnumStr("int", "raise", "floatX"), EnumStr("int", ["raise", "floatX"]),
in_c_key=False, in_c_key=False,
) )
...@@ -101,7 +97,7 @@ AddConfigVar( ...@@ -101,7 +97,7 @@ AddConfigVar(
"non-deterministic implementation, e.g. when we do not have a GPU " "non-deterministic implementation, e.g. when we do not have a GPU "
"implementation that is deterministic. Also see " "implementation that is deterministic. Also see "
"the dnn.conv.algo* flags to cover more cases.", "the dnn.conv.algo* flags to cover more cases.",
EnumStr("default", "more"), EnumStr("default", ["more"]),
in_c_key=False, in_c_key=False,
) )
...@@ -218,7 +214,7 @@ AddConfigVar( ...@@ -218,7 +214,7 @@ AddConfigVar(
CPU overhead when waiting for GPU. One user found that it CPU overhead when waiting for GPU. One user found that it
speeds up his other processes that was doing data augmentation. speeds up his other processes that was doing data augmentation.
""", """,
EnumStr("default", "multi", "single"), EnumStr("default", ["multi", "single"]),
) )
AddConfigVar( AddConfigVar(
...@@ -325,7 +321,7 @@ SUPPORTED_DNN_CONV_PRECISION = ( ...@@ -325,7 +321,7 @@ SUPPORTED_DNN_CONV_PRECISION = (
AddConfigVar( AddConfigVar(
"dnn.conv.algo_fwd", "dnn.conv.algo_fwd",
"Default implementation to use for cuDNN forward convolution.", "Default implementation to use for cuDNN forward convolution.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_FWD), EnumStr("small", SUPPORTED_DNN_CONV_ALGO_FWD),
in_c_key=False, in_c_key=False,
) )
...@@ -333,7 +329,7 @@ AddConfigVar( ...@@ -333,7 +329,7 @@ AddConfigVar(
"dnn.conv.algo_bwd_data", "dnn.conv.algo_bwd_data",
"Default implementation to use for cuDNN backward convolution to " "Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the inputs.", "get the gradients of the convolution with regard to the inputs.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_DATA), EnumStr("none", SUPPORTED_DNN_CONV_ALGO_BWD_DATA),
in_c_key=False, in_c_key=False,
) )
...@@ -342,7 +338,7 @@ AddConfigVar( ...@@ -342,7 +338,7 @@ AddConfigVar(
"Default implementation to use for cuDNN backward convolution to " "Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the " "get the gradients of the convolution with regard to the "
"filters.", "filters.",
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_FILTER), EnumStr("none", SUPPORTED_DNN_CONV_ALGO_BWD_FILTER),
in_c_key=False, in_c_key=False,
) )
...@@ -351,7 +347,7 @@ AddConfigVar( ...@@ -351,7 +347,7 @@ AddConfigVar(
"Default data precision to use for the computation in cuDNN " "Default data precision to use for the computation in cuDNN "
"convolutions (defaults to the same dtype as the inputs of the " "convolutions (defaults to the same dtype as the inputs of the "
"convolutions, or float32 if inputs are float16).", "convolutions, or float32 if inputs are float16).",
EnumStr(*SUPPORTED_DNN_CONV_PRECISION), EnumStr("as_input_f32", SUPPORTED_DNN_CONV_PRECISION),
in_c_key=False, in_c_key=False,
) )
...@@ -434,7 +430,7 @@ AddConfigVar( ...@@ -434,7 +430,7 @@ AddConfigVar(
" If True and cuDNN can not be used, raise an error." " If True and cuDNN can not be used, raise an error."
" If False, disable cudnn even if present." " If False, disable cudnn even if present."
" If no_check, assume present and the version between header and library match (so less compilation at context init)", " If no_check, assume present and the version between header and library match (so less compilation at context init)",
EnumStr("auto", "True", "False", "no_check"), EnumStr("auto", ["True", "False", "no_check"]),
in_c_key=False, in_c_key=False,
) )
...@@ -458,7 +454,7 @@ AddConfigVar( ...@@ -458,7 +454,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"assert_no_cpu_op", "assert_no_cpu_op",
"Raise an error/warning if there is a CPU op in the computational graph.", "Raise an error/warning if there is a CPU op in the computational graph.",
EnumStr("ignore", "warn", "raise", "pdb", mutable=True), EnumStr("ignore", ["warn", "raise", "pdb"], mutable=True),
in_c_key=False, in_c_key=False,
) )
...@@ -583,7 +579,7 @@ if rc == 0 and config.cxx != "": ...@@ -583,7 +579,7 @@ if rc == 0 and config.cxx != "":
AddConfigVar( AddConfigVar(
"linker", "linker",
"Default linker used if the theano flags mode is Mode", "Default linker used if the theano flags mode is Mode",
EnumStr("cvm", "c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"), EnumStr("cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]),
in_c_key=False, in_c_key=False,
) )
else: else:
...@@ -592,7 +588,7 @@ else: ...@@ -592,7 +588,7 @@ else:
AddConfigVar( AddConfigVar(
"linker", "linker",
"Default linker used if the theano flags mode is Mode", "Default linker used if the theano flags mode is Mode",
EnumStr("vm", "py", "vm_nogc"), EnumStr("vm", ["py", "vm_nogc"]),
in_c_key=False, in_c_key=False,
) )
if type(config).cxx.is_default: if type(config).cxx.is_default:
...@@ -623,7 +619,7 @@ AddConfigVar( ...@@ -623,7 +619,7 @@ AddConfigVar(
"optimizer", "optimizer",
"Default optimizer. If not None, will use this optimizer with the Mode", "Default optimizer. If not None, will use this optimizer with the Mode",
EnumStr( EnumStr(
"o4", "o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None" "o4", ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"]
), ),
in_c_key=False, in_c_key=False,
) )
...@@ -641,7 +637,7 @@ AddConfigVar( ...@@ -641,7 +637,7 @@ AddConfigVar(
"What to do when an optimization crashes: warn and skip it, raise " "What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger." "the exception, or fall into the pdb debugger."
), ),
EnumStr("warn", "raise", "pdb", "ignore"), EnumStr("warn", ["raise", "pdb", "ignore"]),
in_c_key=False, in_c_key=False,
) )
...@@ -656,7 +652,7 @@ AddConfigVar( ...@@ -656,7 +652,7 @@ AddConfigVar(
"on_unused_input", "on_unused_input",
"What to do if a variable in the 'inputs' list of " "What to do if a variable in the 'inputs' list of "
" theano.function() is not used in the graph.", " theano.function() is not used in the graph.",
EnumStr("raise", "warn", "ignore"), EnumStr("raise", ["warn", "ignore"]),
in_c_key=False, in_c_key=False,
) )
...@@ -756,7 +752,7 @@ AddConfigVar( ...@@ -756,7 +752,7 @@ AddConfigVar(
"by the following flags: seterr_divide, seterr_over, " "by the following flags: seterr_divide, seterr_over, "
"seterr_under and seterr_invalid.", "seterr_under and seterr_invalid.",
), ),
EnumStr("ignore", "warn", "raise", "call", "print", "log", "None", mutable=False), EnumStr("ignore", ["warn", "raise", "call", "print", "log", "None"], mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -766,7 +762,7 @@ AddConfigVar( ...@@ -766,7 +762,7 @@ AddConfigVar(
"Sets numpy's behavior for division by zero, see numpy.seterr. " "Sets numpy's behavior for division by zero, see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False), EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -777,7 +773,7 @@ AddConfigVar( ...@@ -777,7 +773,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False), EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -788,7 +784,7 @@ AddConfigVar( ...@@ -788,7 +784,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False), EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -799,7 +795,7 @@ AddConfigVar( ...@@ -799,7 +795,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False), EnumStr("None", ["ignore", "warn", "raise", "call", "print", "log"], mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -818,25 +814,27 @@ AddConfigVar( ...@@ -818,25 +814,27 @@ AddConfigVar(
), ),
EnumStr( EnumStr(
"0.9", "0.9",
"None", [
"all", "None",
"0.3", "all",
"0.4", "0.3",
"0.4.1", "0.4",
"0.5", "0.4.1",
"0.6", "0.5",
"0.7", "0.6",
"0.8", "0.7",
"0.8.1", "0.8",
"0.8.2", "0.8.1",
"0.9", "0.8.2",
"0.10", "0.9",
"1.0", "0.10",
"1.0.1", "1.0",
"1.0.2", "1.0.1",
"1.0.3", "1.0.2",
"1.0.4", "1.0.3",
"1.0.5", "1.0.4",
"1.0.5",
],
mutable=False, mutable=False,
), ),
in_c_key=False, in_c_key=False,
...@@ -1005,7 +1003,7 @@ AddConfigVar( ...@@ -1005,7 +1003,7 @@ AddConfigVar(
"to the function. This helps the user track down problems in the " "to the function. This helps the user track down problems in the "
"graph before it gets optimized." "graph before it gets optimized."
), ),
EnumStr("off", "ignore", "warn", "raise", "pdb"), EnumStr("off", ["ignore", "warn", "raise", "pdb"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1030,7 +1028,7 @@ AddConfigVar( ...@@ -1030,7 +1028,7 @@ AddConfigVar(
" Same as compute_test_value, but is used" " Same as compute_test_value, but is used"
" during Theano optimization" " during Theano optimization"
), ),
EnumStr("off", "ignore", "warn", "raise", "pdb"), EnumStr("off", ["ignore", "warn", "raise", "pdb"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1068,7 +1066,7 @@ AddConfigVar( ...@@ -1068,7 +1066,7 @@ AddConfigVar(
A. Elemwise{add_no_inplace} A. Elemwise{add_no_inplace}
B. log_likelihood_v_given_h B. log_likelihood_v_given_h
C. log_likelihood_h""", C. log_likelihood_h""",
EnumStr("low", "high"), EnumStr("low", ["high"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1190,7 +1188,7 @@ AddConfigVar( ...@@ -1190,7 +1188,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"NanGuardMode.action", "NanGuardMode.action",
"What NanGuardMode does when it finds a problem", "What NanGuardMode does when it finds a problem",
EnumStr("raise", "warn", "pdb"), EnumStr("raise", ["warn", "pdb"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1876,7 +1874,7 @@ AddConfigVar( ...@@ -1876,7 +1874,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"on_shape_error", "on_shape_error",
"warn: print a warning and use the default" " value. raise: raise an error", "warn: print a warning and use the default" " value. raise: raise an error",
theano.configparser.EnumStr("warn", "raise"), theano.configparser.EnumStr("warn", ["raise"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1943,7 +1941,7 @@ AddConfigVar( ...@@ -1943,7 +1941,7 @@ AddConfigVar(
"The interaction of which one give the lower peak memory usage is" "The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak" "complicated and not predictable, so if you are close to the peak"
"memory usage, trying both could give you a small gain.", "memory usage, trying both could give you a small gain.",
EnumStr("regular", "fast"), EnumStr("regular", ["fast"]),
in_c_key=False, in_c_key=False,
) )
...@@ -1957,7 +1955,7 @@ AddConfigVar( ...@@ -1957,7 +1955,7 @@ AddConfigVar(
"stack trace is inserted that indicates which optimization inserted" "stack trace is inserted that indicates which optimization inserted"
"the variable that had an empty stack trace." "the variable that had an empty stack trace."
"raise: raises an exception if a stack trace is missing", "raise: raises an exception if a stack trace is missing",
EnumStr("off", "log", "warn", "raise"), EnumStr("off", ["log", "warn", "raise"]),
in_c_key=False, in_c_key=False,
) )
......
...@@ -425,23 +425,37 @@ class ConfigParam: ...@@ -425,23 +425,37 @@ class ConfigParam:
class EnumStr(ConfigParam): class EnumStr(ConfigParam):
def __init__(self, default, *options, **kwargs): def __init__(
self, default: str, options: typing.Sequence[str], validate=None, mutable=True
):
"""Creates a str-based parameter that takes a predefined set of options.
Parameters
----------
default : str
The default setting.
options : sequence
Further str values that the parameter may take.
May, but does not need to include the default.
validate : callable
See `ConfigParam`.
mutable : bool
See `ConfigParam`.
"""
self.all = {default, *options} self.all = {default, *options}
# All options should be strings # All options should be strings
for val in self.all: for val in self.all:
if not isinstance(val, str): if not isinstance(val, str):
raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.") raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.")
super().__init__( super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
default, apply=self._apply, mutable=kwargs.get("mutable", True)
)
def _apply(self, val): def _apply(self, val):
if val in self.all: if val in self.all:
return val return val
else: else:
raise ValueError( raise ValueError(
f'Invalid value ("{val}") for configuration variable "{self.fullname}". ' f"Invalid value ('{val}') for configuration variable '{self.fullname}'. "
f"Valid options are {self.all}" f"Valid options are {self.all}"
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论