提交 d72eed70 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Unify all config parameter types

- Classes instead of functions for the `IntParam`, `StrParam`, etc. - Strong signature for the `ConfigParam` type - `filter` was renamed to `apply`, because that's closer to its functionality - `is_valid` was renamed to `validate` to match `apply` - Both `apply` and `validate` callables can now be set for all params - `DeviceParam` was moved over to where the other config param types are defined - Already deprecated config parameters were removed - Add a few more tests
上级 cbc0a102
""" """Test config options."""
Test config options.
"""
import logging import logging
from unittest.mock import patch from unittest.mock import patch
import pytest
from theano import configparser
from theano.configdefaults import default_blas_ldflags from theano.configdefaults import default_blas_ldflags
from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam
...@@ -12,24 +13,23 @@ def test_invalid_default(): ...@@ -12,24 +13,23 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes # Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user. # a crash if it is not overridden by the user.
def filter(val): def validate(val):
if val == "invalid": if val == "invalid":
raise ValueError() raise ValueError("Test-triggered")
else:
return val
try: with pytest.raises(ValueError, match="Test-triggered"):
# This should raise a ValueError because the default value is # This should raise a ValueError because the default value is
# invalid. # invalid.
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_a", "T_config.test_invalid_default_a",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
raise AssertionError()
except ValueError: THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
pass # This should succeed since we defined a proper value, even
# though the default was invalid.
THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok" THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even # This should succeed since we defined a proper value, even
...@@ -37,15 +37,14 @@ def test_invalid_default(): ...@@ -37,15 +37,14 @@ def test_invalid_default():
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_b", "T_config.test_invalid_default_b",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed # Check that the flag has been removed
assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT
# TODO We should remove these dummy options on test exit.
@patch("theano.configdefaults.try_blas_flag", return_value=None) @patch("theano.configdefaults.try_blas_flag", return_value=None)
@patch("theano.configdefaults.sys") @patch("theano.configdefaults.sys")
...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog): ...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog):
default_blas_ldflags() default_blas_ldflags()
assert "install mkl with" in caplog.text assert "install mkl with" in caplog.text
def test_config_param_apply_and_validation():
cp = ConfigParam(
"TheDeFauLt",
apply=lambda v: v.lower(),
validate=lambda v: v in "thedefault,thesetting",
mutable=True,
)
assert cp.default == "TheDeFauLt"
assert not hasattr(cp, "val")
# can't assign invalid value
with pytest.raises(ValueError, match="Invalid value"):
cp.__set__("cls", "invalid")
assert not hasattr(cp, "val")
# effectivity of apply function
cp.__set__("cls", "THESETTING")
assert cp.val == "thesetting"
# respect the mutability
cp._mutable = False
with pytest.raises(Exception, match="Can't change"):
cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool():
valids = {
True: ["1", 1, True, "true", "True"],
False: ["0", 0, False, "false", "False"],
}
param = configparser.BoolParam(None)
assert isinstance(param, configparser.ConfigParam)
assert param.default is None
for outcome, inputs in valids.items():
for input in inputs:
applied = param.apply(input)
assert applied == outcome
assert param.validate(applied) is not False
with pytest.raises(ValueError, match="Invalid value"):
param.apply("notabool")
差异被折叠。
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import os import os
import shlex import shlex
import sys import sys
import typing
import warnings import warnings
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
...@@ -317,24 +318,86 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -317,24 +318,86 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
class ConfigParam: class ConfigParam:
def __init__(self, default, filter=None, allow_override=True): """Base class of all kinds of configuration parameters.
A ConfigParam has not only default values and configurable mutability, but
also documentation text, as well as filtering and validation routines
that can be context-dependent.
"""
def __init__(
self,
default: typing.Union[object, typing.Callable[[object], object]],
*,
apply: typing.Optional[typing.Callable[[object], object]] = None,
validate: typing.Optional[typing.Callable[[object], bool]] = None,
mutable: bool = True,
):
""" """
If allow_override is False, we can't change the value after the import Represents a configuration parameter and its associated casting and validation logic.
of Theano. So the value should be the same during all the execution.
Parameters
----------
default : object or callable
A default value, or function that returns a default value for this parameter.
apply : callable, optional
Callable that applies a modification to an input value during assignment.
Typical use cases: type casting or expansion of '~' to user home directory.
validate : callable, optional
A callable that validates the parameter value during assignment.
It may raise an (informative!) exception itself, or simply return True/False.
For example to check the availability of a path, device or to restrict a float into a range.
mutable : bool
If mutable is False, the value of this config settings can not be changed at runtime.
""" """
self.default = default self._default = default
self.filter = filter self._apply = apply
self.allow_override = allow_override self._validate = validate
self._mutable = mutable
self.is_default = True self.is_default = True
# N.B. -- # set by AddConfigVar:
# self.fullname # set by AddConfigVar self.fullname = None
# self.doc # set by AddConfigVar self.doc = None
# Note that we do not call `self.filter` on the default value: this # Note that we do not call `self.filter` on the default value: this
# will be done automatically in AddConfigVar, potentially with a # will be done automatically in AddConfigVar, potentially with a
# more appropriate user-provided default value. # more appropriate user-provided default value.
# Calling `filter` here may actually be harmful if the default value is # Calling `filter` here may actually be harmful if the default value is
# invalid and causes a crash or has unwanted side effects. # invalid and causes a crash or has unwanted side effects.
super().__init__()
@property
def default(self):
return self._default
@property
def mutable(self) -> bool:
return self._mutable
def apply(self, value):
"""Applies modifications to a parameter value during assignment.
Typical use cases are casting or the subsitution of '~' with the user home directory.
"""
if callable(self._apply):
return self._apply(value)
return value
def validate(self, value) -> None:
"""Validates that a parameter values falls into a supported set or range.
Raises
------
ValueError
when the validation turns out negative
"""
if not callable(self._validate):
return True
if self._validate(value) is False:
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
return True
def __get__(self, cls, type_, delete_key=False): def __get__(self, cls, type_, delete_key=False):
if cls is None: if cls is None:
...@@ -349,41 +412,31 @@ class ConfigParam: ...@@ -349,41 +412,31 @@ class ConfigParam:
else: else:
val_str = self.default val_str = self.default
self.__set__(cls, val_str) self.__set__(cls, val_str)
# print "RVAL", self.val
return self.val return self.val
def __set__(self, cls, val): def __set__(self, cls, val):
if not self.allow_override and hasattr(self, "val"): if not self.mutable and hasattr(self, "val"):
raise Exception( raise Exception(
"Can't change the value of this config parameter " "Can't change the value of {self.fullname} config parameter after initialization!"
"after initialization!"
) )
# print "SETTING PARAM", self.fullname,(cls), val applied = self.apply(val)
if self.filter: self.validate(applied)
self.val = self.filter(val) self.val = applied
else:
self.val = val
class EnumStr(ConfigParam): class EnumStr(ConfigParam):
def __init__(self, default, *options, **kwargs): def __init__(self, default, *options, **kwargs):
self.default = default self.all = {default, *options}
self.all = (default,) + options
# All options should be strings # All options should be strings
for val in self.all: for val in self.all:
if not isinstance(val, str): if not isinstance(val, str):
raise ValueError( raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.")
"Valid values for an EnumStr parameter " "should be strings", super().__init__(
val, default, apply=self._apply, mutable=kwargs.get("mutable", True)
type(val),
) )
convert = kwargs.get("convert", None) def _apply(self, val):
def filter(val):
if convert:
val = convert(val)
if val in self.all: if val in self.all:
return val return val
else: else:
...@@ -392,63 +445,72 @@ class EnumStr(ConfigParam): ...@@ -392,63 +445,72 @@ class EnumStr(ConfigParam):
f"Valid options are {self.all}" f"Valid options are {self.all}"
) )
over = kwargs.get("allow_override", True)
super().__init__(default, filter, over)
def __str__(self): def __str__(self):
return f"{self.fullname} ({self.all}) " return f"{self.fullname} ({self.all}) "
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype
def filter(val):
cast_val = mytype(val)
if callable(is_valid):
if is_valid(cast_val):
return cast_val
else:
raise ValueError(
f"Invalid value ({val}) for configuration variable "
f'"{self.fullname}".'
)
return cast_val
super().__init__(default, filter, allow_override=allow_override)
def __str__(self): def __str__(self):
return f"{self.fullname} ({self.mytype}) " # The "_apply" callable is the type itself.
return f"{self.fullname} ({self._apply}) "
class StrParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=str, validate=validate, mutable=mutable)
def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid, allow_override=allow_override)
class IntParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=int, validate=validate, mutable=mutable)
def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid, allow_override=allow_override)
class FloatParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=float, validate=validate, mutable=mutable)
def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid, allow_override=allow_override)
class BoolParam(TypedParam):
"""A boolean parameter that may be initialized from any of the following:
False, 0, "false", "False", "0"
True, 1, "true", "True", "1"
"""
def BoolParam(default, is_valid=None, allow_override=True): def __init__(self, default, validate=None, mutable=True):
# see comment at the beginning of this file. super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
def booltype(s): def _apply(self, value):
if s in ["False", "false", "0", False]: if value in {False, 0, "false", "False", "0"}:
return False return False
elif s in ["True", "true", "1", True]: elif value in {True, 1, "true", "True", "1"}:
return True return True
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
def is_valid_bool(s):
if s in ["False", "false", "0", "True", "true", "1", False, True]:
return True
else:
return False
if is_valid is None: class DeviceParam(ConfigParam):
is_valid = is_valid_bool def __init__(self, default, *options, **kwargs):
super().__init__(
default, apply=self._apply, mutable=kwargs.get("mutable", True)
)
def _apply(self, val):
if val == self.default or val.startswith("opencl") or val.startswith("cuda"):
return val
elif val.startswith("gpu"):
raise ValueError(
"You are tring to use the old GPU back-end. "
"It was removed from Theano. Use device=cuda* now. "
"See https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29 "
"for more information."
)
else:
raise ValueError(
'Invalid value ("{val}") for configuration '
'variable "{self.fullname}". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
)
return TypedParam(default, booltype, is_valid, allow_override=allow_override) def __str__(self):
return f"{self.fullname} ({self.default}, opencl*, cuda*) "
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论