提交 d72eed70 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Unify all config parameter types

- Classes instead of functions for the `IntParam`, `StrParam`, etc. - Strong signature for the `ConfigParam` type - `filter` was renamed to `apply`, because that's closer to its functionality - `is_valid` was renamed to `validate` to match `apply` - Both `apply` and `validate` callables can now be set for all params - `DeviceParam` was moved over to where the other config param types are defined - Already deprecated config parameters were removed - Add a few more tests
上级 cbc0a102
""" """Test config options."""
Test config options.
"""
import logging import logging
from unittest.mock import patch from unittest.mock import patch
import pytest
from theano import configparser
from theano.configdefaults import default_blas_ldflags from theano.configdefaults import default_blas_ldflags
from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam
...@@ -12,24 +13,23 @@ def test_invalid_default(): ...@@ -12,24 +13,23 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes # Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user. # a crash if it is not overridden by the user.
def filter(val): def validate(val):
if val == "invalid": if val == "invalid":
raise ValueError() raise ValueError("Test-triggered")
else:
return val
try: with pytest.raises(ValueError, match="Test-triggered"):
# This should raise a ValueError because the default value is # This should raise a ValueError because the default value is
# invalid. # invalid.
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_a", "T_config.test_invalid_default_a",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
raise AssertionError()
except ValueError: THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
pass # This should succeed since we defined a proper value, even
# though the default was invalid.
THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok" THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even # This should succeed since we defined a proper value, even
...@@ -37,15 +37,14 @@ def test_invalid_default(): ...@@ -37,15 +37,14 @@ def test_invalid_default():
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_b", "T_config.test_invalid_default_b",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed # Check that the flag has been removed
assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT
# TODO We should remove these dummy options on test exit.
@patch("theano.configdefaults.try_blas_flag", return_value=None) @patch("theano.configdefaults.try_blas_flag", return_value=None)
@patch("theano.configdefaults.sys") @patch("theano.configdefaults.sys")
...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog): ...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog):
default_blas_ldflags() default_blas_ldflags()
assert "install mkl with" in caplog.text assert "install mkl with" in caplog.text
def test_config_param_apply_and_validation():
cp = ConfigParam(
"TheDeFauLt",
apply=lambda v: v.lower(),
validate=lambda v: v in "thedefault,thesetting",
mutable=True,
)
assert cp.default == "TheDeFauLt"
assert not hasattr(cp, "val")
# can't assign invalid value
with pytest.raises(ValueError, match="Invalid value"):
cp.__set__("cls", "invalid")
assert not hasattr(cp, "val")
# effectivity of apply function
cp.__set__("cls", "THESETTING")
assert cp.val == "thesetting"
# respect the mutability
cp._mutable = False
with pytest.raises(Exception, match="Can't change"):
cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool():
valids = {
True: ["1", 1, True, "true", "True"],
False: ["0", 0, False, "false", "False"],
}
param = configparser.BoolParam(None)
assert isinstance(param, configparser.ConfigParam)
assert param.default is None
for outcome, inputs in valids.items():
for input in inputs:
applied = param.apply(input)
assert applied == outcome
assert param.validate(applied) is not False
with pytest.raises(ValueError, match="Invalid value"):
param.apply("notabool")
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论