Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3ff717fb
提交
3ff717fb
authored
12月 08, 2020
作者:
Michael Osthege
提交者:
Brandon T. Willard
12月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prevent and test cross-instance access
上级
d248e219
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
78 行增加
和
30 行删除
+78
-30
test_config.py
tests/test_config.py
+60
-26
configparser.py
theano/configparser.py
+18
-4
没有找到文件。
tests/test_config.py
浏览文件 @
3ff717fb
"""Test config options."""
import
configparser
as
stdlib_configparser
import
logging
from
unittest.mock
import
patch
...
...
@@ -9,6 +10,14 @@ from theano.configdefaults import default_blas_ldflags
from
theano.configparser
import
ConfigParam
def
_create_test_config
():
return
configparser
.
TheanoConfigParser
(
flags_dict
=
{},
theano_cfg
=
stdlib_configparser
.
ConfigParser
(),
theano_raw_cfg
=
stdlib_configparser
.
RawConfigParser
(),
)
def
test_api_deprecation_warning
():
# accessing through configdefaults.config is the new best practice
with
pytest
.
warns
(
None
):
...
...
@@ -29,7 +38,7 @@ def test_api_deprecation_warning():
def
test_api_redirect
():
root
=
configdefaults
.
config
root
=
_create_test_config
()
# one section level
root
.
add
(
"test__section_redirect"
,
...
...
@@ -59,7 +68,7 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user.
root
=
configdefaults
.
config
root
=
_create_test_config
()
def
validate
(
val
):
if
val
==
"invalid"
:
...
...
@@ -69,25 +78,24 @@ def test_invalid_default():
# This should raise a ValueError because the default value is
# invalid.
root
.
add
(
"
T_config
__test_invalid_default_a"
,
"
test
__test_invalid_default_a"
,
doc
=
"unittest"
,
configparam
=
ConfigParam
(
"invalid"
,
validate
=
validate
),
in_c_key
=
False
,
)
root
.
_flags_dict
[
"
T_config
__test_invalid_default_b"
]
=
"ok"
root
.
_flags_dict
[
"
test
__test_invalid_default_b"
]
=
"ok"
# This should succeed since we defined a proper value, even
# though the default was invalid.
root
.
add
(
"
T_config
__test_invalid_default_b"
,
"
test
__test_invalid_default_b"
,
doc
=
"unittest"
,
configparam
=
ConfigParam
(
"invalid"
,
validate
=
validate
),
in_c_key
=
False
,
)
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed
assert
"
T_config
__test_invalid_default_b"
not
in
root
.
_flags_dict
assert
"
test
__test_invalid_default_b"
not
in
root
.
_flags_dict
@patch
(
"theano.configdefaults.try_blas_flag"
,
return_value
=
None
)
...
...
@@ -130,18 +138,17 @@ def test_config_param_apply_and_validation():
def
test_config_hash
():
# TODO: use custom config instance for the test
root
=
configparser
.
config
root
=
_create_test_config
()
root
.
add
(
"test_config_hash"
,
"test_
_
config_hash"
,
"A config var from a test case."
,
configparser
.
StrParam
(
"test_default"
),
)
h0
=
root
.
get_config_hash
()
with
configparser
.
change_flags
(
test
_config_hash
=
"new_value"
):
assert
root
.
test_config_hash
==
"new_value"
with
root
.
change_flags
(
test_
_config_hash
=
"new_value"
):
assert
root
.
test_
_
config_hash
==
"new_value"
h1
=
root
.
get_config_hash
()
h2
=
root
.
get_config_hash
()
...
...
@@ -150,7 +157,7 @@ def test_config_hash():
def
test_config_print
():
root
=
config
parser
.
config
root
=
config
defaults
.
config
result
=
str
(
root
)
assert
isinstance
(
result
,
str
)
...
...
@@ -192,29 +199,56 @@ class TestConfigTypes:
def
test_config_context
():
# TODO: use custom config instance for the test
root
=
configparser
.
config
root
=
_create_test_config
()
root
.
add
(
"test_config_context"
,
"test_
_
config_context"
,
"A config var from a test case."
,
configparser
.
StrParam
(
"test_default"
),
)
assert
hasattr
(
root
,
"test_config_context"
)
assert
root
.
test_config_context
==
"test_default"
assert
hasattr
(
root
,
"test__config_context"
)
assert
root
.
test__config_context
==
"test_default"
with
root
.
change_flags
(
test__config_context
=
"new_value"
):
assert
root
.
test__config_context
==
"new_value"
with
root
.
change_flags
({
"test__config_context"
:
"new_value2"
}):
assert
root
.
test__config_context
==
"new_value2"
assert
root
.
test__config_context
==
"new_value"
assert
root
.
test__config_context
==
"test_default"
def
test_invalid_configvar_access
():
root
=
configdefaults
.
config
root_test
=
_create_test_config
()
# add a setting to the test instance
root_test
.
add
(
"test__on_test_instance"
,
"This config setting was added to the test instance."
,
configparser
.
IntParam
(
5
),
)
assert
hasattr
(
root_test
,
"test__on_test_instance"
)
# While the property _actually_ exists on all instances,
# accessing it through another instance raises an AttributeError.
assert
not
hasattr
(
root
,
"test__on_test_instance"
)
with
configparser
.
change_flags
(
test_config_context
=
"new_value"
):
assert
root
.
test_config_context
==
"new_value"
with
root
.
change_flags
({
"test_config_context"
:
"new_value2"
}):
assert
root
.
test_config_context
==
"new_value2"
assert
root
.
test_config_context
==
"new_value"
assert
root
.
test_config_context
==
"test_default"
# But we can make sure that nothing crazy happens when we access it:
with
pytest
.
raises
(
configparser
.
ConfigAccessViolation
,
match
=
"different instance"
):
print
(
root
.
test__on_test_instance
)
# And also that we can't add two configs of the same name to different instances:
with
pytest
.
raises
(
AttributeError
,
match
=
"already registered"
):
root
.
add
(
"test__on_test_instance"
,
"This config setting was already added to another instance."
,
configparser
.
IntParam
(
5
),
)
def
test_no_more_dotting
():
root
=
config
parser
.
config
root
=
config
defaults
.
config
with
pytest
.
raises
(
ValueError
,
match
=
"Dot-based"
):
root
.
add
(
"
T_config
.something"
,
"
test
.something"
,
doc
=
"unittest"
,
configparam
=
ConfigParam
(
"invalid"
),
in_c_key
=
False
,
...
...
theano/configparser.py
浏览文件 @
3ff717fb
...
...
@@ -28,6 +28,10 @@ class TheanoConfigWarning(Warning):
warn
=
classmethod
(
warn
)
class
ConfigAccessViolation
(
AttributeError
):
""" Raised when a config setting is accessed through the wrong config instance. """
class
_ChangeFlagsDecorator
:
def
__init__
(
self
,
*
args
,
_root
=
None
,
**
kwargs
):
# the old API supported passing a dict as the first argument:
...
...
@@ -167,11 +171,18 @@ class TheanoConfigParser:
raise
ValueError
(
f
"Dot-based sections were removed. Use double underscores! ({name})"
)
if
hasattr
(
self
,
name
):
raise
AttributeError
(
f
"The name {name} is already taken"
)
# Can't use hasattr here, because it returns False upon AttributeErrors
if
name
in
dir
(
self
):
raise
AttributeError
(
f
"A config parameter with the name '{name}' was already registered on another config instance."
)
configparam
.
doc
=
doc
configparam
.
name
=
name
configparam
.
in_c_key
=
in_c_key
# Register it on this instance before the code below already starts accessing it
self
.
_config_var_dict
[
name
]
=
configparam
# Trigger a read of the value from config files and env vars
# This allow to filter wrong value from the user.
if
not
callable
(
configparam
.
default
):
...
...
@@ -193,8 +204,6 @@ class TheanoConfigParser:
# the ConfigParam implements __get__/__set__, enabling us to create a property:
setattr
(
self
.
__class__
,
name
,
configparam
)
# keep the ConfigParam object in a dictionary:
self
.
_config_var_dict
[
name
]
=
configparam
# The old API used dots for accessing a hierarchy of sections.
# The following code adds redirects that spill DeprecationWarnings
...
...
@@ -343,6 +352,11 @@ class ConfigParam:
def
__get__
(
self
,
cls
,
type_
,
delete_key
=
False
):
if
cls
is
None
:
return
self
if
self
.
name
not
in
cls
.
_config_var_dict
:
raise
ConfigAccessViolation
(
f
"The config parameter '{self.name}' was registered on a different instance of the TheanoConfigParser."
f
" It is not accessible through the instance with id '{id(cls)}' because of safeguarding."
)
if
not
hasattr
(
self
,
"val"
):
try
:
val_str
=
cls
.
fetch_val_for_key
(
self
.
name
,
delete_key
=
delete_key
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论