提交 d72eed70 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Unify all config parameter types

- Classes instead of functions for the `IntParam`, `StrParam`, etc. - Strong signature for the `ConfigParam` type - `filter` was renamed to `apply`, because that's closer to its functionality - `is_valid` was renamed to `validate` to match `apply` - Both `apply` and `validate` callables can now be set for all params - `DeviceParam` was moved over to where the other config param types are defined - Already deprecated config parameters were removed - Add a few more tests
上级 cbc0a102
"""
Test config options.
"""
"""Test config options."""
import logging
from unittest.mock import patch
import pytest
from theano import configparser
from theano.configdefaults import default_blas_ldflags
from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam
......@@ -12,24 +13,23 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user.
def filter(val):
def validate(val):
if val == "invalid":
raise ValueError()
else:
return val
raise ValueError("Test-triggered")
try:
with pytest.raises(ValueError, match="Test-triggered"):
# This should raise a ValueError because the default value is
# invalid.
AddConfigVar(
"T_config.test_invalid_default_a",
doc="unittest",
configparam=ConfigParam("invalid", filter=filter),
configparam=ConfigParam("invalid", validate=validate),
in_c_key=False,
)
raise AssertionError()
except ValueError:
pass
THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even
# though the default was invalid.
THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even
......@@ -37,15 +37,14 @@ def test_invalid_default():
AddConfigVar(
"T_config.test_invalid_default_b",
doc="unittest",
configparam=ConfigParam("invalid", filter=filter),
configparam=ConfigParam("invalid", validate=validate),
in_c_key=False,
)
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed
assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT
# TODO We should remove these dummy options on test exit.
@patch("theano.configdefaults.try_blas_flag", return_value=None)
@patch("theano.configdefaults.sys")
......@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog):
default_blas_ldflags()
assert "install mkl with" in caplog.text
def test_config_param_apply_and_validation():
cp = ConfigParam(
"TheDeFauLt",
apply=lambda v: v.lower(),
validate=lambda v: v in "thedefault,thesetting",
mutable=True,
)
assert cp.default == "TheDeFauLt"
assert not hasattr(cp, "val")
# can't assign invalid value
with pytest.raises(ValueError, match="Invalid value"):
cp.__set__("cls", "invalid")
assert not hasattr(cp, "val")
# effectivity of apply function
cp.__set__("cls", "THESETTING")
assert cp.val == "thesetting"
# respect the mutability
cp._mutable = False
with pytest.raises(Exception, match="Can't change"):
cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool():
valids = {
True: ["1", 1, True, "true", "True"],
False: ["0", 0, False, "false", "False"],
}
param = configparser.BoolParam(None)
assert isinstance(param, configparser.ConfigParam)
assert param.default is None
for outcome, inputs in valids.items():
for input in inputs:
applied = param.apply(input)
assert applied == outcome
assert param.validate(applied) is not False
with pytest.raises(ValueError, match="Invalid value"):
param.apply("notabool")
差异被折叠。
......@@ -4,6 +4,7 @@ import logging
import os
import shlex
import sys
import typing
import warnings
from functools import wraps
from io import StringIO
......@@ -317,24 +318,86 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
class ConfigParam:
def __init__(self, default, filter=None, allow_override=True):
"""Base class of all kinds of configuration parameters.
A ConfigParam has not only default values and configurable mutability, but
also documentation text, as well as filtering and validation routines
that can be context-dependent.
"""
def __init__(
self,
default: typing.Union[object, typing.Callable[[object], object]],
*,
apply: typing.Optional[typing.Callable[[object], object]] = None,
validate: typing.Optional[typing.Callable[[object], bool]] = None,
mutable: bool = True,
):
"""
If allow_override is False, we can't change the value after the import
of Theano. So the value should be the same during all the execution.
Represents a configuration parameter and its associated casting and validation logic.
Parameters
----------
default : object or callable
A default value, or function that returns a default value for this parameter.
apply : callable, optional
Callable that applies a modification to an input value during assignment.
Typical use cases: type casting or expansion of '~' to user home directory.
validate : callable, optional
A callable that validates the parameter value during assignment.
It may raise an (informative!) exception itself, or simply return True/False.
For example to check the availability of a path, device or to restrict a float into a range.
mutable : bool
If mutable is False, the value of this config settings can not be changed at runtime.
"""
self.default = default
self.filter = filter
self.allow_override = allow_override
self._default = default
self._apply = apply
self._validate = validate
self._mutable = mutable
self.is_default = True
# N.B. --
# self.fullname # set by AddConfigVar
# self.doc # set by AddConfigVar
# set by AddConfigVar:
self.fullname = None
self.doc = None
# Note that we do not call `self.filter` on the default value: this
# will be done automatically in AddConfigVar, potentially with a
# more appropriate user-provided default value.
# Calling `filter` here may actually be harmful if the default value is
# invalid and causes a crash or has unwanted side effects.
super().__init__()
@property
def default(self):
return self._default
@property
def mutable(self) -> bool:
return self._mutable
def apply(self, value):
"""Applies modifications to a parameter value during assignment.
Typical use cases are casting or the subsitution of '~' with the user home directory.
"""
if callable(self._apply):
return self._apply(value)
return value
def validate(self, value) -> None:
"""Validates that a parameter values falls into a supported set or range.
Raises
------
ValueError
when the validation turns out negative
"""
if not callable(self._validate):
return True
if self._validate(value) is False:
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
return True
def __get__(self, cls, type_, delete_key=False):
if cls is None:
......@@ -349,41 +412,31 @@ class ConfigParam:
else:
val_str = self.default
self.__set__(cls, val_str)
# print "RVAL", self.val
return self.val
def __set__(self, cls, val):
if not self.allow_override and hasattr(self, "val"):
if not self.mutable and hasattr(self, "val"):
raise Exception(
"Can't change the value of this config parameter "
"after initialization!"
"Can't change the value of {self.fullname} config parameter after initialization!"
)
# print "SETTING PARAM", self.fullname,(cls), val
if self.filter:
self.val = self.filter(val)
else:
self.val = val
applied = self.apply(val)
self.validate(applied)
self.val = applied
class EnumStr(ConfigParam):
def __init__(self, default, *options, **kwargs):
self.default = default
self.all = (default,) + options
self.all = {default, *options}
# All options should be strings
for val in self.all:
if not isinstance(val, str):
raise ValueError(
"Valid values for an EnumStr parameter " "should be strings",
val,
type(val),
raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.")
super().__init__(
default, apply=self._apply, mutable=kwargs.get("mutable", True)
)
convert = kwargs.get("convert", None)
def filter(val):
if convert:
val = convert(val)
def _apply(self, val):
if val in self.all:
return val
else:
......@@ -392,63 +445,72 @@ class EnumStr(ConfigParam):
f"Valid options are {self.all}"
)
over = kwargs.get("allow_override", True)
super().__init__(default, filter, over)
def __str__(self):
return f"{self.fullname} ({self.all}) "
class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype
def filter(val):
cast_val = mytype(val)
if callable(is_valid):
if is_valid(cast_val):
return cast_val
else:
raise ValueError(
f"Invalid value ({val}) for configuration variable "
f'"{self.fullname}".'
)
return cast_val
super().__init__(default, filter, allow_override=allow_override)
def __str__(self):
return f"{self.fullname} ({self.mytype}) "
# The "_apply" callable is the type itself.
return f"{self.fullname} ({self._apply}) "
class StrParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=str, validate=validate, mutable=mutable)
def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid, allow_override=allow_override)
class IntParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=int, validate=validate, mutable=mutable)
def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid, allow_override=allow_override)
class FloatParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=float, validate=validate, mutable=mutable)
def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid, allow_override=allow_override)
class BoolParam(TypedParam):
"""A boolean parameter that may be initialized from any of the following:
False, 0, "false", "False", "0"
True, 1, "true", "True", "1"
"""
def BoolParam(default, is_valid=None, allow_override=True):
# see comment at the beginning of this file.
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
def booltype(s):
if s in ["False", "false", "0", False]:
def _apply(self, value):
if value in {False, 0, "false", "False", "0"}:
return False
elif s in ["True", "true", "1", True]:
elif value in {True, 1, "true", "True", "1"}:
return True
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
def is_valid_bool(s):
if s in ["False", "false", "0", "True", "true", "1", False, True]:
return True
else:
return False
if is_valid is None:
is_valid = is_valid_bool
class DeviceParam(ConfigParam):
def __init__(self, default, *options, **kwargs):
super().__init__(
default, apply=self._apply, mutable=kwargs.get("mutable", True)
)
def _apply(self, val):
if val == self.default or val.startswith("opencl") or val.startswith("cuda"):
return val
elif val.startswith("gpu"):
raise ValueError(
"You are tring to use the old GPU back-end. "
"It was removed from Theano. Use device=cuda* now. "
"See https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29 "
"for more information."
)
else:
raise ValueError(
'Invalid value ("{val}") for configuration '
'variable "{self.fullname}". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
)
return TypedParam(default, booltype, is_valid, allow_override=allow_override)
def __str__(self):
return f"{self.fullname} ({self.default}, opencl*, cuda*) "
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论