提交 3ff717fb authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Prevent and test cross-instance access

上级 d248e219
"""Test config options."""
import configparser as stdlib_configparser
import logging
from unittest.mock import patch
......@@ -9,6 +10,14 @@ from theano.configdefaults import default_blas_ldflags
from theano.configparser import ConfigParam
def _create_test_config():
return configparser.TheanoConfigParser(
flags_dict={},
theano_cfg=stdlib_configparser.ConfigParser(),
theano_raw_cfg=stdlib_configparser.RawConfigParser(),
)
def test_api_deprecation_warning():
# accessing through configdefaults.config is the new best practice
with pytest.warns(None):
......@@ -29,7 +38,7 @@ def test_api_deprecation_warning():
def test_api_redirect():
root = configdefaults.config
root = _create_test_config()
# one section level
root.add(
"test__section_redirect",
......@@ -59,7 +68,7 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user.
root = configdefaults.config
root = _create_test_config()
def validate(val):
if val == "invalid":
......@@ -69,25 +78,24 @@ def test_invalid_default():
# This should raise a ValueError because the default value is
# invalid.
root.add(
"T_config__test_invalid_default_a",
"test__test_invalid_default_a",
doc="unittest",
configparam=ConfigParam("invalid", validate=validate),
in_c_key=False,
)
root._flags_dict["T_config__test_invalid_default_b"] = "ok"
root._flags_dict["test__test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even
# though the default was invalid.
root.add(
"T_config__test_invalid_default_b",
"test__test_invalid_default_b",
doc="unittest",
configparam=ConfigParam("invalid", validate=validate),
in_c_key=False,
)
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed
assert "T_config__test_invalid_default_b" not in root._flags_dict
assert "test__test_invalid_default_b" not in root._flags_dict
@patch("theano.configdefaults.try_blas_flag", return_value=None)
......@@ -130,18 +138,17 @@ def test_config_param_apply_and_validation():
def test_config_hash():
# TODO: use custom config instance for the test
root = configparser.config
root = _create_test_config()
root.add(
"test_config_hash",
"test__config_hash",
"A config var from a test case.",
configparser.StrParam("test_default"),
)
h0 = root.get_config_hash()
with configparser.change_flags(test_config_hash="new_value"):
assert root.test_config_hash == "new_value"
with root.change_flags(test__config_hash="new_value"):
assert root.test__config_hash == "new_value"
h1 = root.get_config_hash()
h2 = root.get_config_hash()
......@@ -150,7 +157,7 @@ def test_config_hash():
def test_config_print():
root = configparser.config
root = configdefaults.config
result = str(root)
assert isinstance(result, str)
......@@ -192,29 +199,56 @@ class TestConfigTypes:
def test_config_context():
# TODO: use custom config instance for the test
root = configparser.config
root = _create_test_config()
root.add(
"test_config_context",
"test__config_context",
"A config var from a test case.",
configparser.StrParam("test_default"),
)
assert hasattr(root, "test_config_context")
assert root.test_config_context == "test_default"
assert hasattr(root, "test__config_context")
assert root.test__config_context == "test_default"
with root.change_flags(test__config_context="new_value"):
assert root.test__config_context == "new_value"
with root.change_flags({"test__config_context": "new_value2"}):
assert root.test__config_context == "new_value2"
assert root.test__config_context == "new_value"
assert root.test__config_context == "test_default"
def test_invalid_configvar_access():
root = configdefaults.config
root_test = _create_test_config()
# add a setting to the test instance
root_test.add(
"test__on_test_instance",
"This config setting was added to the test instance.",
configparser.IntParam(5),
)
assert hasattr(root_test, "test__on_test_instance")
# While the property _actually_ exists on all instances,
# accessing it through another instance raises an AttributeError.
assert not hasattr(root, "test__on_test_instance")
with configparser.change_flags(test_config_context="new_value"):
assert root.test_config_context == "new_value"
with root.change_flags({"test_config_context": "new_value2"}):
assert root.test_config_context == "new_value2"
assert root.test_config_context == "new_value"
assert root.test_config_context == "test_default"
# But we can make sure that nothing crazy happens when we access it:
with pytest.raises(configparser.ConfigAccessViolation, match="different instance"):
print(root.test__on_test_instance)
# And also that we can't add two configs of the same name to different instances:
with pytest.raises(AttributeError, match="already registered"):
root.add(
"test__on_test_instance",
"This config setting was already added to another instance.",
configparser.IntParam(5),
)
def test_no_more_dotting():
root = configparser.config
root = configdefaults.config
with pytest.raises(ValueError, match="Dot-based"):
root.add(
"T_config.something",
"test.something",
doc="unittest",
configparam=ConfigParam("invalid"),
in_c_key=False,
......
......@@ -28,6 +28,10 @@ class TheanoConfigWarning(Warning):
warn = classmethod(warn)
class ConfigAccessViolation(AttributeError):
""" Raised when a config setting is accessed through the wrong config instance. """
class _ChangeFlagsDecorator:
def __init__(self, *args, _root=None, **kwargs):
# the old API supported passing a dict as the first argument:
......@@ -167,11 +171,18 @@ class TheanoConfigParser:
raise ValueError(
f"Dot-based sections were removed. Use double underscores! ({name})"
)
if hasattr(self, name):
raise AttributeError(f"The name {name} is already taken")
# Can't use hasattr here, because it returns False upon AttributeErrors
if name in dir(self):
raise AttributeError(
f"A config parameter with the name '{name}' was already registered on another config instance."
)
configparam.doc = doc
configparam.name = name
configparam.in_c_key = in_c_key
# Register it on this instance before the code below already starts accessing it
self._config_var_dict[name] = configparam
# Trigger a read of the value from config files and env vars
# This allow to filter wrong value from the user.
if not callable(configparam.default):
......@@ -193,8 +204,6 @@ class TheanoConfigParser:
# the ConfigParam implements __get__/__set__, enabling us to create a property:
setattr(self.__class__, name, configparam)
# keep the ConfigParam object in a dictionary:
self._config_var_dict[name] = configparam
# The old API used dots for accessing a hierarchy of sections.
# The following code adds redirects that spill DeprecationWarnings
......@@ -343,6 +352,11 @@ class ConfigParam:
def __get__(self, cls, type_, delete_key=False):
if cls is None:
return self
if self.name not in cls._config_var_dict:
raise ConfigAccessViolation(
f"The config parameter '{self.name}' was registered on a different instance of the TheanoConfigParser."
f" It is not accessible through the instance with id '{id(cls)}' because of safeguarding."
)
if not hasattr(self, "val"):
try:
val_str = cls.fetch_val_for_key(self.name, delete_key=delete_key)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论