提交 3ff717fb authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Prevent and test cross-instance access

上级 d248e219
"""Test config options.""" """Test config options."""
import configparser as stdlib_configparser
import logging import logging
from unittest.mock import patch from unittest.mock import patch
...@@ -9,6 +10,14 @@ from theano.configdefaults import default_blas_ldflags ...@@ -9,6 +10,14 @@ from theano.configdefaults import default_blas_ldflags
from theano.configparser import ConfigParam from theano.configparser import ConfigParam
def _create_test_config():
return configparser.TheanoConfigParser(
flags_dict={},
theano_cfg=stdlib_configparser.ConfigParser(),
theano_raw_cfg=stdlib_configparser.RawConfigParser(),
)
def test_api_deprecation_warning(): def test_api_deprecation_warning():
# accessing through configdefaults.config is the new best practice # accessing through configdefaults.config is the new best practice
with pytest.warns(None): with pytest.warns(None):
...@@ -29,7 +38,7 @@ def test_api_deprecation_warning(): ...@@ -29,7 +38,7 @@ def test_api_deprecation_warning():
def test_api_redirect(): def test_api_redirect():
root = configdefaults.config root = _create_test_config()
# one section level # one section level
root.add( root.add(
"test__section_redirect", "test__section_redirect",
...@@ -59,7 +68,7 @@ def test_invalid_default(): ...@@ -59,7 +68,7 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes # Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user. # a crash if it is not overridden by the user.
root = configdefaults.config root = _create_test_config()
def validate(val): def validate(val):
if val == "invalid": if val == "invalid":
...@@ -69,25 +78,24 @@ def test_invalid_default(): ...@@ -69,25 +78,24 @@ def test_invalid_default():
# This should raise a ValueError because the default value is # This should raise a ValueError because the default value is
# invalid. # invalid.
root.add( root.add(
"T_config__test_invalid_default_a", "test__test_invalid_default_a",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", validate=validate), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
root._flags_dict["T_config__test_invalid_default_b"] = "ok" root._flags_dict["test__test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even # This should succeed since we defined a proper value, even
# though the default was invalid. # though the default was invalid.
root.add( root.add(
"T_config__test_invalid_default_b", "test__test_invalid_default_b",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", validate=validate), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed # Check that the flag has been removed
assert "T_config__test_invalid_default_b" not in root._flags_dict assert "test__test_invalid_default_b" not in root._flags_dict
@patch("theano.configdefaults.try_blas_flag", return_value=None) @patch("theano.configdefaults.try_blas_flag", return_value=None)
...@@ -130,18 +138,17 @@ def test_config_param_apply_and_validation(): ...@@ -130,18 +138,17 @@ def test_config_param_apply_and_validation():
def test_config_hash(): def test_config_hash():
# TODO: use custom config instance for the test root = _create_test_config()
root = configparser.config
root.add( root.add(
"test_config_hash", "test__config_hash",
"A config var from a test case.", "A config var from a test case.",
configparser.StrParam("test_default"), configparser.StrParam("test_default"),
) )
h0 = root.get_config_hash() h0 = root.get_config_hash()
with configparser.change_flags(test_config_hash="new_value"): with root.change_flags(test__config_hash="new_value"):
assert root.test_config_hash == "new_value" assert root.test__config_hash == "new_value"
h1 = root.get_config_hash() h1 = root.get_config_hash()
h2 = root.get_config_hash() h2 = root.get_config_hash()
...@@ -150,7 +157,7 @@ def test_config_hash(): ...@@ -150,7 +157,7 @@ def test_config_hash():
def test_config_print(): def test_config_print():
root = configparser.config root = configdefaults.config
result = str(root) result = str(root)
assert isinstance(result, str) assert isinstance(result, str)
...@@ -192,29 +199,56 @@ class TestConfigTypes: ...@@ -192,29 +199,56 @@ class TestConfigTypes:
def test_config_context(): def test_config_context():
# TODO: use custom config instance for the test root = _create_test_config()
root = configparser.config
root.add( root.add(
"test_config_context", "test__config_context",
"A config var from a test case.", "A config var from a test case.",
configparser.StrParam("test_default"), configparser.StrParam("test_default"),
) )
assert hasattr(root, "test_config_context") assert hasattr(root, "test__config_context")
assert root.test_config_context == "test_default" assert root.test__config_context == "test_default"
with root.change_flags(test__config_context="new_value"):
assert root.test__config_context == "new_value"
with root.change_flags({"test__config_context": "new_value2"}):
assert root.test__config_context == "new_value2"
assert root.test__config_context == "new_value"
assert root.test__config_context == "test_default"
def test_invalid_configvar_access():
root = configdefaults.config
root_test = _create_test_config()
# add a setting to the test instance
root_test.add(
"test__on_test_instance",
"This config setting was added to the test instance.",
configparser.IntParam(5),
)
assert hasattr(root_test, "test__on_test_instance")
# While the property _actually_ exists on all instances,
# accessing it through another instance raises an AttributeError.
assert not hasattr(root, "test__on_test_instance")
with configparser.change_flags(test_config_context="new_value"): # But we can make sure that nothing crazy happens when we access it:
assert root.test_config_context == "new_value" with pytest.raises(configparser.ConfigAccessViolation, match="different instance"):
with root.change_flags({"test_config_context": "new_value2"}): print(root.test__on_test_instance)
assert root.test_config_context == "new_value2"
assert root.test_config_context == "new_value" # And also that we can't add two configs of the same name to different instances:
assert root.test_config_context == "test_default" with pytest.raises(AttributeError, match="already registered"):
root.add(
"test__on_test_instance",
"This config setting was already added to another instance.",
configparser.IntParam(5),
)
def test_no_more_dotting(): def test_no_more_dotting():
root = configparser.config root = configdefaults.config
with pytest.raises(ValueError, match="Dot-based"): with pytest.raises(ValueError, match="Dot-based"):
root.add( root.add(
"T_config.something", "test.something",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid"), configparam=ConfigParam("invalid"),
in_c_key=False, in_c_key=False,
......
...@@ -28,6 +28,10 @@ class TheanoConfigWarning(Warning): ...@@ -28,6 +28,10 @@ class TheanoConfigWarning(Warning):
warn = classmethod(warn) warn = classmethod(warn)
class ConfigAccessViolation(AttributeError):
""" Raised when a config setting is accessed through the wrong config instance. """
class _ChangeFlagsDecorator: class _ChangeFlagsDecorator:
def __init__(self, *args, _root=None, **kwargs): def __init__(self, *args, _root=None, **kwargs):
# the old API supported passing a dict as the first argument: # the old API supported passing a dict as the first argument:
...@@ -167,11 +171,18 @@ class TheanoConfigParser: ...@@ -167,11 +171,18 @@ class TheanoConfigParser:
raise ValueError( raise ValueError(
f"Dot-based sections were removed. Use double underscores! ({name})" f"Dot-based sections were removed. Use double underscores! ({name})"
) )
if hasattr(self, name): # Can't use hasattr here, because it returns False upon AttributeErrors
raise AttributeError(f"The name {name} is already taken") if name in dir(self):
raise AttributeError(
f"A config parameter with the name '{name}' was already registered on another config instance."
)
configparam.doc = doc configparam.doc = doc
configparam.name = name configparam.name = name
configparam.in_c_key = in_c_key configparam.in_c_key = in_c_key
# Register it on this instance before the code below already starts accessing it
self._config_var_dict[name] = configparam
# Trigger a read of the value from config files and env vars # Trigger a read of the value from config files and env vars
# This allow to filter wrong value from the user. # This allow to filter wrong value from the user.
if not callable(configparam.default): if not callable(configparam.default):
...@@ -193,8 +204,6 @@ class TheanoConfigParser: ...@@ -193,8 +204,6 @@ class TheanoConfigParser:
# the ConfigParam implements __get__/__set__, enabling us to create a property: # the ConfigParam implements __get__/__set__, enabling us to create a property:
setattr(self.__class__, name, configparam) setattr(self.__class__, name, configparam)
# keep the ConfigParam object in a dictionary:
self._config_var_dict[name] = configparam
# The old API used dots for accessing a hierarchy of sections. # The old API used dots for accessing a hierarchy of sections.
# The following code adds redirects that spill DeprecationWarnings # The following code adds redirects that spill DeprecationWarnings
...@@ -343,6 +352,11 @@ class ConfigParam: ...@@ -343,6 +352,11 @@ class ConfigParam:
def __get__(self, cls, type_, delete_key=False): def __get__(self, cls, type_, delete_key=False):
if cls is None: if cls is None:
return self return self
if self.name not in cls._config_var_dict:
raise ConfigAccessViolation(
f"The config parameter '{self.name}' was registered on a different instance of the TheanoConfigParser."
f" It is not accessible through the instance with id '{id(cls)}' because of safeguarding."
)
if not hasattr(self, "val"): if not hasattr(self, "val"):
try: try:
val_str = cls.fetch_val_for_key(self.name, delete_key=delete_key) val_str = cls.fetch_val_for_key(self.name, delete_key=delete_key)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论