Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8cce5c55
提交
8cce5c55
authored
11月 28, 2020
作者:
Michael Osthege
提交者:
Brandon T. Willard
12月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Always pass EnumStr options as sequence
上级
d72eed70
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
105 行增加
和
76 行删除
+105
-76
test_config.py
tests/test_config.py
+36
-19
configdefaults.py
theano/configdefaults.py
+50
-52
configparser.py
theano/configparser.py
+19
-5
没有找到文件。
tests/test_config.py
浏览文件 @
8cce5c55
...
...
@@ -85,22 +85,39 @@ def test_config_param_apply_and_validation():
cp
.
__set__
(
"cls"
,
"THEDEFAULT"
)
def
test_config_types_bool
():
valids
=
{
True
:
[
"1"
,
1
,
True
,
"true"
,
"True"
],
False
:
[
"0"
,
0
,
False
,
"false"
,
"False"
],
}
param
=
configparser
.
BoolParam
(
None
)
assert
isinstance
(
param
,
configparser
.
ConfigParam
)
assert
param
.
default
is
None
for
outcome
,
inputs
in
valids
.
items
():
for
input
in
inputs
:
applied
=
param
.
apply
(
input
)
assert
applied
==
outcome
assert
param
.
validate
(
applied
)
is
not
False
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid value"
):
param
.
apply
(
"notabool"
)
class
TestConfigTypes
:
def
test_bool
(
self
):
valids
=
{
True
:
[
"1"
,
1
,
True
,
"true"
,
"True"
],
False
:
[
"0"
,
0
,
False
,
"false"
,
"False"
],
}
param
=
configparser
.
BoolParam
(
None
)
assert
isinstance
(
param
,
configparser
.
ConfigParam
)
assert
param
.
default
is
None
for
outcome
,
inputs
in
valids
.
items
():
for
input
in
inputs
:
applied
=
param
.
apply
(
input
)
assert
applied
==
outcome
assert
param
.
validate
(
applied
)
is
not
False
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid value"
):
param
.
apply
(
"notabool"
)
pass
def
test_enumstr
(
self
):
cp
=
configparser
.
EnumStr
(
"blue"
,
[
"red"
,
"green"
,
"yellow"
])
assert
len
(
cp
.
all
)
==
4
with
pytest
.
raises
(
ValueError
,
match
=
r"Invalid value \('foo'\)"
):
cp
.
apply
(
"foo"
)
with
pytest
.
raises
(
ValueError
,
match
=
"Non-str value"
):
configparser
.
EnumStr
(
default
=
"red"
,
options
=
[
"red"
,
12
,
"yellow"
])
pass
def
test_deviceparam
(
self
):
cp
=
configparser
.
DeviceParam
(
"cpu"
,
mutable
=
False
)
assert
cp
.
default
==
"cpu"
assert
cp
.
_apply
(
"cuda123"
)
==
"cuda123"
with
pytest
.
raises
(
ValueError
,
match
=
"old GPU back-end"
):
cp
.
_apply
(
"gpu123"
)
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid value"
):
cp
.
_apply
(
"notadevice"
)
assert
str
(
cp
)
==
"None (cpu, opencl*, cuda*) "
theano/configdefaults.py
浏览文件 @
8cce5c55
...
...
@@ -39,11 +39,7 @@ AddConfigVar(
"Default floating-point precision for python casts.
\n
"
"
\n
"
"Note: float16 support is experimental, use at your own risk."
,
EnumStr
(
"float64"
,
"float32"
,
"float16"
,
),
EnumStr
(
"float64"
,
[
"float32"
,
"float16"
]),
# TODO: see gh-4466 for how to remove it.
in_c_key
=
True
,
)
...
...
@@ -53,7 +49,7 @@ AddConfigVar(
"Do an action when a tensor variable with float64 dtype is"
" created. They can't be run on the GPU with the current(old)"
" gpu back-end and are slow with gamer GPUs."
,
EnumStr
(
"ignore"
,
"warn"
,
"raise"
,
"pdb"
),
EnumStr
(
"ignore"
,
[
"warn"
,
"raise"
,
"pdb"
]
),
in_c_key
=
False
,
)
...
...
@@ -70,7 +66,7 @@ AddConfigVar(
"Rules for implicit type casting"
,
EnumStr
(
"custom"
,
"numpy+floatX"
,
[
"numpy+floatX"
]
,
# The 'numpy' policy was originally planned to provide a
# smooth transition from numpy. It was meant to behave the
# same as numpy+floatX, but keeping float64 when numpy
...
...
@@ -89,7 +85,7 @@ AddConfigVar(
AddConfigVar
(
"int_division"
,
"What to do when one computes x / y, where both x and y are of "
"integer types"
,
EnumStr
(
"int"
,
"raise"
,
"floatX"
),
EnumStr
(
"int"
,
[
"raise"
,
"floatX"
]
),
in_c_key
=
False
,
)
...
...
@@ -101,7 +97,7 @@ AddConfigVar(
"non-deterministic implementaion, e.g. when we do not have a GPU "
"implementation that is deterministic. Also see "
"the dnn.conv.algo* flags to cover more cases."
,
EnumStr
(
"default"
,
"more"
),
EnumStr
(
"default"
,
[
"more"
]
),
in_c_key
=
False
,
)
...
...
@@ -218,7 +214,7 @@ AddConfigVar(
CPU overhead when waiting for GPU. One user found that it
speeds up his other processes that was doing data augmentation.
"""
,
EnumStr
(
"default"
,
"multi"
,
"single"
),
EnumStr
(
"default"
,
[
"multi"
,
"single"
]
),
)
AddConfigVar
(
...
...
@@ -325,7 +321,7 @@ SUPPORTED_DNN_CONV_PRECISION = (
AddConfigVar
(
"dnn.conv.algo_fwd"
,
"Default implementation to use for cuDNN forward convolution."
,
EnumStr
(
*
SUPPORTED_DNN_CONV_ALGO_FWD
),
EnumStr
(
"small"
,
SUPPORTED_DNN_CONV_ALGO_FWD
),
in_c_key
=
False
,
)
...
...
@@ -333,7 +329,7 @@ AddConfigVar(
"dnn.conv.algo_bwd_data"
,
"Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the inputs."
,
EnumStr
(
*
SUPPORTED_DNN_CONV_ALGO_BWD_DATA
),
EnumStr
(
"none"
,
SUPPORTED_DNN_CONV_ALGO_BWD_DATA
),
in_c_key
=
False
,
)
...
...
@@ -342,7 +338,7 @@ AddConfigVar(
"Default implementation to use for cuDNN backward convolution to "
"get the gradients of the convolution with regard to the "
"filters."
,
EnumStr
(
*
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
),
EnumStr
(
"none"
,
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
),
in_c_key
=
False
,
)
...
...
@@ -351,7 +347,7 @@ AddConfigVar(
"Default data precision to use for the computation in cuDNN "
"convolutions (defaults to the same dtype as the inputs of the "
"convolutions, or float32 if inputs are float16)."
,
EnumStr
(
*
SUPPORTED_DNN_CONV_PRECISION
),
EnumStr
(
"as_input_f32"
,
SUPPORTED_DNN_CONV_PRECISION
),
in_c_key
=
False
,
)
...
...
@@ -434,7 +430,7 @@ AddConfigVar(
" If True and cuDNN can not be used, raise an error."
" If False, disable cudnn even if present."
" If no_check, assume present and the version between header and library match (so less compilation at context init)"
,
EnumStr
(
"auto"
,
"True"
,
"False"
,
"no_check"
),
EnumStr
(
"auto"
,
[
"True"
,
"False"
,
"no_check"
]
),
in_c_key
=
False
,
)
...
...
@@ -458,7 +454,7 @@ AddConfigVar(
AddConfigVar
(
"assert_no_cpu_op"
,
"Raise an error/warning if there is a CPU op in the computational graph."
,
EnumStr
(
"ignore"
,
"warn"
,
"raise"
,
"pdb"
,
mutable
=
True
),
EnumStr
(
"ignore"
,
[
"warn"
,
"raise"
,
"pdb"
]
,
mutable
=
True
),
in_c_key
=
False
,
)
...
...
@@ -583,7 +579,7 @@ if rc == 0 and config.cxx != "":
AddConfigVar
(
"linker"
,
"Default linker used if the theano flags mode is Mode"
,
EnumStr
(
"cvm"
,
"c|py"
,
"py"
,
"c"
,
"c|py_nogc"
,
"vm"
,
"vm_nogc"
,
"cvm_nogc"
),
EnumStr
(
"cvm"
,
[
"c|py"
,
"py"
,
"c"
,
"c|py_nogc"
,
"vm"
,
"vm_nogc"
,
"cvm_nogc"
]
),
in_c_key
=
False
,
)
else
:
...
...
@@ -592,7 +588,7 @@ else:
AddConfigVar
(
"linker"
,
"Default linker used if the theano flags mode is Mode"
,
EnumStr
(
"vm"
,
"py"
,
"vm_nogc"
),
EnumStr
(
"vm"
,
[
"py"
,
"vm_nogc"
]
),
in_c_key
=
False
,
)
if
type
(
config
)
.
cxx
.
is_default
:
...
...
@@ -623,7 +619,7 @@ AddConfigVar(
"optimizer"
,
"Default optimizer. If not None, will use this optimizer with the Mode"
,
EnumStr
(
"o4"
,
"o3"
,
"o2"
,
"o1"
,
"unsafe"
,
"fast_run"
,
"fast_compile"
,
"merge"
,
"None"
"o4"
,
[
"o3"
,
"o2"
,
"o1"
,
"unsafe"
,
"fast_run"
,
"fast_compile"
,
"merge"
,
"None"
]
),
in_c_key
=
False
,
)
...
...
@@ -641,7 +637,7 @@ AddConfigVar(
"What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."
),
EnumStr
(
"warn"
,
"raise"
,
"pdb"
,
"ignore"
),
EnumStr
(
"warn"
,
[
"raise"
,
"pdb"
,
"ignore"
]
),
in_c_key
=
False
,
)
...
...
@@ -656,7 +652,7 @@ AddConfigVar(
"on_unused_input"
,
"What to do if a variable in the 'inputs' list of "
" theano.function() is not used in the graph."
,
EnumStr
(
"raise"
,
"warn"
,
"ignore"
),
EnumStr
(
"raise"
,
[
"warn"
,
"ignore"
]
),
in_c_key
=
False
,
)
...
...
@@ -756,7 +752,7 @@ AddConfigVar(
"by the following flags: seterr_divide, seterr_over, "
"seterr_under and seterr_invalid."
,
),
EnumStr
(
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
"None"
,
mutable
=
False
),
EnumStr
(
"ignore"
,
[
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
"None"
]
,
mutable
=
False
),
in_c_key
=
False
,
)
...
...
@@ -766,7 +762,7 @@ AddConfigVar(
"Sets numpy's behavior for division by zero, see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr
(
"None"
,
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
mutable
=
False
),
EnumStr
(
"None"
,
[
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
]
,
mutable
=
False
),
in_c_key
=
False
,
)
...
...
@@ -777,7 +773,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr
(
"None"
,
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
mutable
=
False
),
EnumStr
(
"None"
,
[
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
]
,
mutable
=
False
),
in_c_key
=
False
,
)
...
...
@@ -788,7 +784,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr
(
"None"
,
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
mutable
=
False
),
EnumStr
(
"None"
,
[
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
]
,
mutable
=
False
),
in_c_key
=
False
,
)
...
...
@@ -799,7 +795,7 @@ AddConfigVar(
"see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all."
),
EnumStr
(
"None"
,
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
,
mutable
=
False
),
EnumStr
(
"None"
,
[
"ignore"
,
"warn"
,
"raise"
,
"call"
,
"print"
,
"log"
]
,
mutable
=
False
),
in_c_key
=
False
,
)
...
...
@@ -818,25 +814,27 @@ AddConfigVar(
),
EnumStr
(
"0.9"
,
"None"
,
"all"
,
"0.3"
,
"0.4"
,
"0.4.1"
,
"0.5"
,
"0.6"
,
"0.7"
,
"0.8"
,
"0.8.1"
,
"0.8.2"
,
"0.9"
,
"0.10"
,
"1.0"
,
"1.0.1"
,
"1.0.2"
,
"1.0.3"
,
"1.0.4"
,
"1.0.5"
,
[
"None"
,
"all"
,
"0.3"
,
"0.4"
,
"0.4.1"
,
"0.5"
,
"0.6"
,
"0.7"
,
"0.8"
,
"0.8.1"
,
"0.8.2"
,
"0.9"
,
"0.10"
,
"1.0"
,
"1.0.1"
,
"1.0.2"
,
"1.0.3"
,
"1.0.4"
,
"1.0.5"
,
],
mutable
=
False
,
),
in_c_key
=
False
,
...
...
@@ -1005,7 +1003,7 @@ AddConfigVar(
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."
),
EnumStr
(
"off"
,
"ignore"
,
"warn"
,
"raise"
,
"pdb"
),
EnumStr
(
"off"
,
[
"ignore"
,
"warn"
,
"raise"
,
"pdb"
]
),
in_c_key
=
False
,
)
...
...
@@ -1030,7 +1028,7 @@ AddConfigVar(
" Same as compute_test_value, but is used"
" during Theano optimization"
),
EnumStr
(
"off"
,
"ignore"
,
"warn"
,
"raise"
,
"pdb"
),
EnumStr
(
"off"
,
[
"ignore"
,
"warn"
,
"raise"
,
"pdb"
]
),
in_c_key
=
False
,
)
...
...
@@ -1068,7 +1066,7 @@ AddConfigVar(
A. Elemwise{add_no_inplace}
B. log_likelihood_v_given_h
C. log_likelihood_h"""
,
EnumStr
(
"low"
,
"high"
),
EnumStr
(
"low"
,
[
"high"
]
),
in_c_key
=
False
,
)
...
...
@@ -1190,7 +1188,7 @@ AddConfigVar(
AddConfigVar
(
"NanGuardMode.action"
,
"What NanGuardMode does when it finds a problem"
,
EnumStr
(
"raise"
,
"warn"
,
"pdb"
),
EnumStr
(
"raise"
,
[
"warn"
,
"pdb"
]
),
in_c_key
=
False
,
)
...
...
@@ -1876,7 +1874,7 @@ AddConfigVar(
AddConfigVar
(
"on_shape_error"
,
"warn: print a warning and use the default"
" value. raise: raise an error"
,
theano
.
configparser
.
EnumStr
(
"warn"
,
"raise"
),
theano
.
configparser
.
EnumStr
(
"warn"
,
[
"raise"
]
),
in_c_key
=
False
,
)
...
...
@@ -1943,7 +1941,7 @@ AddConfigVar(
"The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak"
"memory usage, triyng both could give you a small gain."
,
EnumStr
(
"regular"
,
"fast"
),
EnumStr
(
"regular"
,
[
"fast"
]
),
in_c_key
=
False
,
)
...
...
@@ -1957,7 +1955,7 @@ AddConfigVar(
"stack trace is inserted that indicates which optimization inserted"
"the variable that had an empty stack trace."
"raise: raises an exception if a stack trace is missing"
,
EnumStr
(
"off"
,
"log"
,
"warn"
,
"raise"
),
EnumStr
(
"off"
,
[
"log"
,
"warn"
,
"raise"
]
),
in_c_key
=
False
,
)
...
...
theano/configparser.py
浏览文件 @
8cce5c55
...
...
@@ -425,23 +425,37 @@ class ConfigParam:
class
EnumStr
(
ConfigParam
):
def
__init__
(
self
,
default
,
*
options
,
**
kwargs
):
def
__init__
(
self
,
default
:
str
,
options
:
typing
.
Sequence
[
str
],
validate
=
None
,
mutable
=
True
):
"""Creates a str-based parameter that takes a predefined set of options.
Parameters
----------
default : str
The default setting.
options : sequence
Further str values that the parameter may take.
May, but does not need to include the default.
validate : callable
See `ConfigParam`.
mutable : callable
See `ConfigParam`.
"""
self
.
all
=
{
default
,
*
options
}
# All options should be strings
for
val
in
self
.
all
:
if
not
isinstance
(
val
,
str
):
raise
ValueError
(
f
"Non-str value '{val}' for an EnumStr parameter."
)
super
()
.
__init__
(
default
,
apply
=
self
.
_apply
,
mutable
=
kwargs
.
get
(
"mutable"
,
True
)
)
super
()
.
__init__
(
default
,
apply
=
self
.
_apply
,
validate
=
validate
,
mutable
=
mutable
)
def
_apply
(
self
,
val
):
if
val
in
self
.
all
:
return
val
else
:
raise
ValueError
(
f
'Invalid value ("{val}") for configuration variable "{self.fullname}". '
f
"Invalid value ('{val}') for configuration variable '{self.fullname}'. "
f
"Valid options are {self.all}"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论