提交 d72eed70 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Unify all config parameter types

- Classes instead of functions for the `IntParam`, `StrParam`, etc. - Strong signature for the `ConfigParam` type - `filter` was renamed to `apply`, because that's closer to its functionality - `is_valid` was renamed to `validate` to match `apply` - Both `apply` and `validate` callables can now be set for all params - `DeviceParam` was moved over to where the other config param types are defined - Already deprecated config parameters were removed - Add a few more tests
上级 cbc0a102
""" """Test config options."""
Test config options.
"""
import logging import logging
from unittest.mock import patch from unittest.mock import patch
import pytest
from theano import configparser
from theano.configdefaults import default_blas_ldflags from theano.configdefaults import default_blas_ldflags
from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam from theano.configparser import THEANO_FLAGS_DICT, AddConfigVar, ConfigParam
...@@ -12,24 +13,23 @@ def test_invalid_default(): ...@@ -12,24 +13,23 @@ def test_invalid_default():
# Ensure an invalid default value found in the Theano code only causes # Ensure an invalid default value found in the Theano code only causes
# a crash if it is not overridden by the user. # a crash if it is not overridden by the user.
def filter(val): def validate(val):
if val == "invalid": if val == "invalid":
raise ValueError() raise ValueError("Test-triggered")
else:
return val
try: with pytest.raises(ValueError, match="Test-triggered"):
# This should raise a ValueError because the default value is # This should raise a ValueError because the default value is
# invalid. # invalid.
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_a", "T_config.test_invalid_default_a",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
raise AssertionError()
except ValueError: THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
pass # This should succeed since we defined a proper value, even
# though the default was invalid.
THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok" THEANO_FLAGS_DICT["T_config.test_invalid_default_b"] = "ok"
# This should succeed since we defined a proper value, even # This should succeed since we defined a proper value, even
...@@ -37,15 +37,14 @@ def test_invalid_default(): ...@@ -37,15 +37,14 @@ def test_invalid_default():
AddConfigVar( AddConfigVar(
"T_config.test_invalid_default_b", "T_config.test_invalid_default_b",
doc="unittest", doc="unittest",
configparam=ConfigParam("invalid", filter=filter), configparam=ConfigParam("invalid", validate=validate),
in_c_key=False, in_c_key=False,
) )
# TODO We should remove these dummy options on test exit.
# Check that the flag has been removed # Check that the flag has been removed
assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT assert "T_config.test_invalid_default_b" not in THEANO_FLAGS_DICT
# TODO We should remove these dummy options on test exit.
@patch("theano.configdefaults.try_blas_flag", return_value=None) @patch("theano.configdefaults.try_blas_flag", return_value=None)
@patch("theano.configdefaults.sys") @patch("theano.configdefaults.sys")
...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog): ...@@ -58,3 +57,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog):
default_blas_ldflags() default_blas_ldflags()
assert "install mkl with" in caplog.text assert "install mkl with" in caplog.text
def test_config_param_apply_and_validation():
cp = ConfigParam(
"TheDeFauLt",
apply=lambda v: v.lower(),
validate=lambda v: v in "thedefault,thesetting",
mutable=True,
)
assert cp.default == "TheDeFauLt"
assert not hasattr(cp, "val")
# can't assign invalid value
with pytest.raises(ValueError, match="Invalid value"):
cp.__set__("cls", "invalid")
assert not hasattr(cp, "val")
# effectivity of apply function
cp.__set__("cls", "THESETTING")
assert cp.val == "thesetting"
# respect the mutability
cp._mutable = False
with pytest.raises(Exception, match="Can't change"):
cp.__set__("cls", "THEDEFAULT")
def test_config_types_bool():
valids = {
True: ["1", 1, True, "true", "True"],
False: ["0", 0, False, "false", "False"],
}
param = configparser.BoolParam(None)
assert isinstance(param, configparser.ConfigParam)
assert param.default is None
for outcome, inputs in valids.items():
for input in inputs:
applied = param.apply(input)
assert applied == outcome
assert param.validate(applied) is not False
with pytest.raises(ValueError, match="Invalid value"):
param.apply("notabool")
...@@ -17,6 +17,7 @@ from theano.configparser import ( ...@@ -17,6 +17,7 @@ from theano.configparser import (
AddConfigVar, AddConfigVar,
BoolParam, BoolParam,
ConfigParam, ConfigParam,
DeviceParam,
EnumStr, EnumStr,
FloatParam, FloatParam,
IntParam, IntParam,
...@@ -109,38 +110,6 @@ AddConfigVar( ...@@ -109,38 +110,6 @@ AddConfigVar(
# gpuX mean use the gpu number X. # gpuX mean use the gpu number X.
class DeviceParam(ConfigParam):
def __init__(self, default, *options, **kwargs):
self.default = default
def filter(val):
if (
val == self.default
or val.startswith("opencl")
or val.startswith("cuda")
):
return val
elif val.startswith("gpu"):
raise ValueError(
"You are tring to use the old GPU back-end. "
"It was removed from Theano. Use device=cuda* now. "
"See https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29 "
"for more information."
)
else:
raise ValueError(
'Invalid value ("{val}") for configuration '
'variable "{self.fullname}". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
)
over = kwargs.get("allow_override", True)
super().__init__(default, filter, over)
def __str__(self):
return f"{self.fullname} ({self.default}, opencl*, cuda*) "
AddConfigVar( AddConfigVar(
"device", "device",
( (
...@@ -148,7 +117,7 @@ AddConfigVar( ...@@ -148,7 +117,7 @@ AddConfigVar(
"default to try to move computation to the GPU. Do not use upper case" "default to try to move computation to the GPU. Do not use upper case"
"letters, only lower case even if NVIDIA uses capital letters." "letters, only lower case even if NVIDIA uses capital letters."
), ),
DeviceParam("cpu", allow_override=False), DeviceParam("cpu", mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -160,14 +129,14 @@ AddConfigVar( ...@@ -160,14 +129,14 @@ AddConfigVar(
"nor shared variables, to the specified GPU. " "nor shared variables, to the specified GPU. "
"It can be used to run GPU-specific tests on a particular GPU." "It can be used to run GPU-specific tests on a particular GPU."
), ),
DeviceParam("", allow_override=False), DeviceParam("", mutable=False),
in_c_key=False, in_c_key=False,
) )
AddConfigVar( AddConfigVar(
"force_device", "force_device",
"Raise an error if we can't use the specified device", "Raise an error if we can't use the specified device",
BoolParam(False, allow_override=False), BoolParam(False, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -205,7 +174,7 @@ class ContextsParam(ConfigParam): ...@@ -205,7 +174,7 @@ class ContextsParam(ConfigParam):
raise ValueError(f"Cannot use {s[0]} as context name") raise ValueError(f"Cannot use {s[0]} as context name")
return val return val
ConfigParam.__init__(self, "", filter, False) ConfigParam.__init__(self, "", apply=filter, mutable=False)
AddConfigVar( AddConfigVar(
...@@ -226,23 +195,7 @@ AddConfigVar( ...@@ -226,23 +195,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"print_active_device", "print_active_device",
"Print active device at when the GPU device is initialized.", "Print active device at when the GPU device is initialized.",
BoolParam(True, allow_override=False), BoolParam(True, mutable=False),
in_c_key=False,
)
def deprecated_gpuarray_sync(val):
if val:
raise RuntimeError(
"Flag gpuarray.sync is deprecated and will be removed in next Theano release."
)
return False
AddConfigVar(
"gpuarray.sync",
"""This flag is deprecated and will be removed in next Theano release.""",
ConfigParam(False, allow_override=False, filter=deprecated_gpuarray_sync),
in_c_key=False, in_c_key=False,
) )
...@@ -253,7 +206,7 @@ AddConfigVar( ...@@ -253,7 +206,7 @@ AddConfigVar(
preallocates that fraction of the total GPU memory. If 1 preallocates that fraction of the total GPU memory. If 1
or greater it will preallocate that amount of memory (in or greater it will preallocate that amount of memory (in
megabytes).""", megabytes).""",
FloatParam(0, allow_override=False), FloatParam(0, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -323,62 +276,6 @@ AddConfigVar( ...@@ -323,62 +276,6 @@ AddConfigVar(
) )
def safe_no_dnn_workmem(workmem):
"""
Make sure the user is not attempting to use dnn.conv.workmem`.
"""
if workmem:
raise RuntimeError(
"The option `dnn.conv.workmem` has been removed and should "
"not be used anymore. Please use the option "
"`dnn.conv.algo_fwd` instead."
)
return True
AddConfigVar(
"dnn.conv.workmem",
"This flag is deprecated; use dnn.conv.algo_fwd.",
ConfigParam("", allow_override=False, filter=safe_no_dnn_workmem),
in_c_key=False,
)
def safe_no_dnn_workmem_bwd(workmem):
"""
Make sure the user is not attempting to use dnn.conv.workmem_bwd`.
"""
if workmem:
raise RuntimeError(
"The option `dnn.conv.workmem_bwd` has been removed and "
"should not be used anymore. Please use the options "
"`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead."
)
return True
AddConfigVar(
"dnn.conv.workmem_bwd",
"This flag is deprecated; use `dnn.conv.algo_bwd_filter` "
"and `dnn.conv.algo_bwd_data` instead.",
ConfigParam("", allow_override=False, filter=safe_no_dnn_workmem_bwd),
in_c_key=False,
)
def safe_no_dnn_algo_bwd(algo):
"""
Make sure the user is not attempting to use dnn.conv.algo_bwd`.
"""
if algo:
raise RuntimeError(
"The option `dnn.conv.algo_bwd` has been removed and "
"should not be used anymore. Please use the options "
"`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead."
)
return True
# Those are the options provided by Theano to choose algorithms at runtime. # Those are the options provided by Theano to choose algorithms at runtime.
SUPPORTED_DNN_CONV_ALGO_RUNTIME = ( SUPPORTED_DNN_CONV_ALGO_RUNTIME = (
"guess_once", "guess_once",
...@@ -425,14 +322,6 @@ SUPPORTED_DNN_CONV_PRECISION = ( ...@@ -425,14 +322,6 @@ SUPPORTED_DNN_CONV_PRECISION = (
"float64", "float64",
) )
AddConfigVar(
"dnn.conv.algo_bwd",
"This flag is deprecated; use dnn.conv.algo_bwd_data and "
"dnn.conv.algo_bwd_filter.",
ConfigParam("", allow_override=False, filter=safe_no_dnn_algo_bwd),
in_c_key=False,
)
AddConfigVar( AddConfigVar(
"dnn.conv.algo_fwd", "dnn.conv.algo_fwd",
"Default implementation to use for cuDNN forward convolution.", "Default implementation to use for cuDNN forward convolution.",
...@@ -569,7 +458,7 @@ AddConfigVar( ...@@ -569,7 +458,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"assert_no_cpu_op", "assert_no_cpu_op",
"Raise an error/warning if there is a CPU op in the computational graph.", "Raise an error/warning if there is a CPU op in the computational graph.",
EnumStr("ignore", "warn", "raise", "pdb", allow_override=True), EnumStr("ignore", "warn", "raise", "pdb", mutable=True),
in_c_key=False, in_c_key=False,
) )
...@@ -604,7 +493,10 @@ def filter_mode(val): ...@@ -604,7 +493,10 @@ def filter_mode(val):
AddConfigVar( AddConfigVar(
"mode", "Default compilation mode", ConfigParam("Mode", filter_mode), in_c_key=False "mode",
"Default compilation mode",
ConfigParam("Mode", apply=filter_mode),
in_c_key=False,
) )
param = "g++" param = "g++"
...@@ -675,7 +567,7 @@ AddConfigVar( ...@@ -675,7 +567,7 @@ AddConfigVar(
" supported, but supporting additional compilers should not be " " supported, but supporting additional compilers should not be "
"too difficult. " "too difficult. "
"If it is empty, no C++ code is compiled.", "If it is empty, no C++ code is compiled.",
StrParam(param, is_valid=warn_cxx), StrParam(param, validate=warn_cxx),
in_c_key=False, in_c_key=False,
) )
del param del param
...@@ -775,7 +667,7 @@ AddConfigVar( ...@@ -775,7 +667,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"tensor.cmp_sloppy", "tensor.cmp_sloppy",
"Relax tensor._allclose (0) not at all, (1) a bit, (2) more", "Relax tensor._allclose (0) not at all, (1) a bit, (2) more",
IntParam(0, lambda i: i in (0, 1, 2), allow_override=False), IntParam(0, lambda i: i in (0, 1, 2), mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -864,9 +756,7 @@ AddConfigVar( ...@@ -864,9 +756,7 @@ AddConfigVar(
"by the following flags: seterr_divide, seterr_over, " "by the following flags: seterr_divide, seterr_over, "
"seterr_under and seterr_invalid.", "seterr_under and seterr_invalid.",
), ),
EnumStr( EnumStr("ignore", "warn", "raise", "call", "print", "log", "None", mutable=False),
"ignore", "warn", "raise", "call", "print", "log", "None", allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -876,9 +766,7 @@ AddConfigVar( ...@@ -876,9 +766,7 @@ AddConfigVar(
"Sets numpy's behavior for division by zero, see numpy.seterr. " "Sets numpy's behavior for division by zero, see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr( EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
"None", "ignore", "warn", "raise", "call", "print", "log", allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -889,9 +777,7 @@ AddConfigVar( ...@@ -889,9 +777,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr( EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
"None", "ignore", "warn", "raise", "call", "print", "log", allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -902,9 +788,7 @@ AddConfigVar( ...@@ -902,9 +788,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr( EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
"None", "ignore", "warn", "raise", "call", "print", "log", allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -915,9 +799,7 @@ AddConfigVar( ...@@ -915,9 +799,7 @@ AddConfigVar(
"see numpy.seterr. " "see numpy.seterr. "
"'None' means using the default, defined by numpy.seterr_all." "'None' means using the default, defined by numpy.seterr_all."
), ),
EnumStr( EnumStr("None", "ignore", "warn", "raise", "call", "print", "log", mutable=False),
"None", "ignore", "warn", "raise", "call", "print", "log", allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -955,7 +837,7 @@ AddConfigVar( ...@@ -955,7 +837,7 @@ AddConfigVar(
"1.0.3", "1.0.3",
"1.0.4", "1.0.4",
"1.0.5", "1.0.5",
allow_override=False, mutable=False,
), ),
in_c_key=False, in_c_key=False,
) )
...@@ -1166,7 +1048,7 @@ AddConfigVar( ...@@ -1166,7 +1048,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"reoptimize_unpickled_function", "reoptimize_unpickled_function",
"Re-optimize the graph when a theano function is unpickled from the disk.", "Re-optimize the graph when a theano function is unpickled from the disk.",
BoolParam(False, allow_override=True), BoolParam(False, mutable=True),
in_c_key=False, in_c_key=False,
) )
...@@ -1280,7 +1162,7 @@ AddConfigVar( ...@@ -1280,7 +1162,7 @@ AddConfigVar(
"unittests.rseed", "unittests.rseed",
"Seed to use for randomized unit tests. " "Seed to use for randomized unit tests. "
"Special value 'random' means using a seed of None.", "Special value 'random' means using a seed of None.",
StrParam(666, is_valid=good_seed_param), StrParam(666, validate=good_seed_param),
in_c_key=False, in_c_key=False,
) )
...@@ -1318,7 +1200,7 @@ AddConfigVar( ...@@ -1318,7 +1200,7 @@ AddConfigVar(
"When using the default mode, we will remove optimizer with " "When using the default mode, we will remove optimizer with "
"these tags. Separate tags with ':'." "these tags. Separate tags with ':'."
), ),
StrParam("", allow_override=False), StrParam("", mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1328,7 +1210,7 @@ AddConfigVar( ...@@ -1328,7 +1210,7 @@ AddConfigVar(
"When using the default mode, we will add optimizer with " "When using the default mode, we will add optimizer with "
"these tags. Separate tags with ':'." "these tags. Separate tags with ':'."
), ),
StrParam("", allow_override=False), StrParam("", mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1338,7 +1220,7 @@ AddConfigVar( ...@@ -1338,7 +1220,7 @@ AddConfigVar(
"When using the default mode, we will require optimizer with " "When using the default mode, we will require optimizer with "
"these tags. Separate tags with ':'." "these tags. Separate tags with ':'."
), ),
StrParam("", allow_override=False), StrParam("", mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1423,7 +1305,7 @@ AddConfigVar( ...@@ -1423,7 +1305,7 @@ AddConfigVar(
'"wrong_size" (larger and smaller dimensions), and ' '"wrong_size" (larger and smaller dimensions), and '
'"ALL" (all of the above).' '"ALL" (all of the above).'
), ),
StrParam("", is_valid=is_valid_check_preallocated_output_param), StrParam("", validate=is_valid_check_preallocated_output_param),
in_c_key=False, in_c_key=False,
) )
...@@ -1565,7 +1447,7 @@ AddConfigVar( ...@@ -1565,7 +1447,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"cmodule.preload_cache", "cmodule.preload_cache",
"If set to True, will preload the C module cache at import time", "If set to True, will preload the C module cache at import time",
BoolParam(False, allow_override=False), BoolParam(False, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1573,7 +1455,7 @@ AddConfigVar( ...@@ -1573,7 +1455,7 @@ AddConfigVar(
"cmodule.age_thresh_use", "cmodule.age_thresh_use",
"In seconds. The time after which " "Theano won't reuse a compile c module.", "In seconds. The time after which " "Theano won't reuse a compile c module.",
# 24 days # 24 days
IntParam(60 * 60 * 24 * 24, allow_override=False), IntParam(60 * 60 * 24 * 24, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1979,7 +1861,7 @@ AddConfigVar( ...@@ -1979,7 +1861,7 @@ AddConfigVar(
" auto detect if lazy evaluation is needed and use the appropriate" " auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between" " version. If lazy is True/False, force the version used between"
" Loop/LoopGC and Stack.", " Loop/LoopGC and Stack.",
ConfigParam("None", filter_vm_lazy), ConfigParam("None", apply=filter_vm_lazy),
in_c_key=False, in_c_key=False,
) )
...@@ -2012,7 +1894,7 @@ AddConfigVar( ...@@ -2012,7 +1894,7 @@ AddConfigVar(
" Generates error if not True. Use" " Generates error if not True. Use"
" optimizer_excluding=local_alloc_elemwise" " optimizer_excluding=local_alloc_elemwise"
" to dsiable.", " to dsiable.",
theano.configparser.BoolParam(True, is_valid=lambda x: x), theano.configparser.BoolParam(True),
in_c_key=False, in_c_key=False,
) )
...@@ -2049,7 +1931,7 @@ AddConfigVar( ...@@ -2049,7 +1931,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
"compile.wait", "compile.wait",
"""Time to wait before retrying to acquire the compile lock.""", """Time to wait before retrying to acquire the compile lock.""",
IntParam(5, lambda i: i > 0, allow_override=False), IntParam(5, validate=lambda i: i > 0, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -2091,7 +1973,7 @@ override an existing lock. An override only happens when the existing ...@@ -2091,7 +1973,7 @@ override an existing lock. An override only happens when the existing
lock is held by the same owner *and* has not been 'refreshed' by this lock is held by the same owner *and* has not been 'refreshed' by this
owner for more than this period. Refreshes are done every half timeout owner for more than this period. Refreshes are done every half timeout
period for running processes.""", period for running processes.""",
IntParam(_timeout_default, lambda i: i >= 0, allow_override=False), IntParam(_timeout_default, validate=lambda i: i >= 0, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -2240,7 +2122,7 @@ AddConfigVar( ...@@ -2240,7 +2122,7 @@ AddConfigVar(
""" """
) )
), ),
StrParam(default_compiledir_format, allow_override=False), StrParam(default_compiledir_format, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -2331,9 +2213,7 @@ else: ...@@ -2331,9 +2213,7 @@ else:
AddConfigVar( AddConfigVar(
"base_compiledir", "base_compiledir",
"platform-independent root directory for compiled modules", "platform-independent root directory for compiled modules",
ConfigParam( ConfigParam(default_base_compiledir, apply=filter_base_compiledir, mutable=False),
default_base_compiledir, filter=filter_base_compiledir, allow_override=False
),
in_c_key=False, in_c_key=False,
) )
...@@ -2345,7 +2225,7 @@ def default_compiledir(): ...@@ -2345,7 +2225,7 @@ def default_compiledir():
AddConfigVar( AddConfigVar(
"compiledir", "compiledir",
"platform-dependent cache directory for compiled modules", "platform-dependent cache directory for compiled modules",
ConfigParam(default_compiledir, filter=filter_compiledir, allow_override=False), ConfigParam(default_compiledir, apply=filter_compiledir, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -2354,8 +2234,8 @@ AddConfigVar( ...@@ -2354,8 +2234,8 @@ AddConfigVar(
"Directory to cache pre-compiled kernels for the gpuarray backend.", "Directory to cache pre-compiled kernels for the gpuarray backend.",
ConfigParam( ConfigParam(
lambda: os.path.join(config.compiledir, "gpuarray_kernels"), lambda: os.path.join(config.compiledir, "gpuarray_kernels"),
filter=filter_base_compiledir, apply=filter_base_compiledir,
allow_override=False, mutable=False,
), ),
in_c_key=False, in_c_key=False,
) )
...@@ -2365,7 +2245,7 @@ AddConfigVar( ...@@ -2365,7 +2245,7 @@ AddConfigVar(
"Directory which contains the root of Baidu CTC library. It is assumed \ "Directory which contains the root of Baidu CTC library. It is assumed \
that the compiled library is either inside the build, lib or lib64 \ that the compiled library is either inside the build, lib or lib64 \
subdirectory, and the header inside the include directory.", subdirectory, and the header inside the include directory.",
StrParam("", allow_override=False), StrParam("", mutable=False),
in_c_key=False, in_c_key=False,
) )
......
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import os import os
import shlex import shlex
import sys import sys
import typing
import warnings import warnings
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
...@@ -317,24 +318,86 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -317,24 +318,86 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
class ConfigParam: class ConfigParam:
def __init__(self, default, filter=None, allow_override=True): """Base class of all kinds of configuration parameters.
A ConfigParam has not only default values and configurable mutability, but
also documentation text, as well as filtering and validation routines
that can be context-dependent.
"""
def __init__(
self,
default: typing.Union[object, typing.Callable[[object], object]],
*,
apply: typing.Optional[typing.Callable[[object], object]] = None,
validate: typing.Optional[typing.Callable[[object], bool]] = None,
mutable: bool = True,
):
""" """
If allow_override is False, we can't change the value after the import Represents a configuration parameter and its associated casting and validation logic.
of Theano. So the value should be the same during all the execution.
Parameters
----------
default : object or callable
A default value, or function that returns a default value for this parameter.
apply : callable, optional
Callable that applies a modification to an input value during assignment.
Typical use cases: type casting or expansion of '~' to user home directory.
validate : callable, optional
A callable that validates the parameter value during assignment.
It may raise an (informative!) exception itself, or simply return True/False.
For example to check the availability of a path, device or to restrict a float into a range.
mutable : bool
If mutable is False, the value of this config settings can not be changed at runtime.
""" """
self.default = default self._default = default
self.filter = filter self._apply = apply
self.allow_override = allow_override self._validate = validate
self._mutable = mutable
self.is_default = True self.is_default = True
# N.B. -- # set by AddConfigVar:
# self.fullname # set by AddConfigVar self.fullname = None
# self.doc # set by AddConfigVar self.doc = None
# Note that we do not call `self.filter` on the default value: this # Note that we do not call `self.filter` on the default value: this
# will be done automatically in AddConfigVar, potentially with a # will be done automatically in AddConfigVar, potentially with a
# more appropriate user-provided default value. # more appropriate user-provided default value.
# Calling `filter` here may actually be harmful if the default value is # Calling `filter` here may actually be harmful if the default value is
# invalid and causes a crash or has unwanted side effects. # invalid and causes a crash or has unwanted side effects.
super().__init__()
@property
def default(self):
return self._default
@property
def mutable(self) -> bool:
return self._mutable
def apply(self, value):
"""Applies modifications to a parameter value during assignment.
Typical use cases are casting or the subsitution of '~' with the user home directory.
"""
if callable(self._apply):
return self._apply(value)
return value
def validate(self, value) -> None:
"""Validates that a parameter values falls into a supported set or range.
Raises
------
ValueError
when the validation turns out negative
"""
if not callable(self._validate):
return True
if self._validate(value) is False:
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
return True
def __get__(self, cls, type_, delete_key=False): def __get__(self, cls, type_, delete_key=False):
if cls is None: if cls is None:
...@@ -349,41 +412,31 @@ class ConfigParam: ...@@ -349,41 +412,31 @@ class ConfigParam:
else: else:
val_str = self.default val_str = self.default
self.__set__(cls, val_str) self.__set__(cls, val_str)
# print "RVAL", self.val
return self.val return self.val
def __set__(self, cls, val): def __set__(self, cls, val):
if not self.allow_override and hasattr(self, "val"): if not self.mutable and hasattr(self, "val"):
raise Exception( raise Exception(
"Can't change the value of this config parameter " "Can't change the value of {self.fullname} config parameter after initialization!"
"after initialization!"
) )
# print "SETTING PARAM", self.fullname,(cls), val applied = self.apply(val)
if self.filter: self.validate(applied)
self.val = self.filter(val) self.val = applied
else:
self.val = val
class EnumStr(ConfigParam): class EnumStr(ConfigParam):
def __init__(self, default, *options, **kwargs): def __init__(self, default, *options, **kwargs):
self.default = default self.all = {default, *options}
self.all = (default,) + options
# All options should be strings # All options should be strings
for val in self.all: for val in self.all:
if not isinstance(val, str): if not isinstance(val, str):
raise ValueError( raise ValueError(f"Non-str value '{val}' for an EnumStr parameter.")
"Valid values for an EnumStr parameter " "should be strings", super().__init__(
val, default, apply=self._apply, mutable=kwargs.get("mutable", True)
type(val),
) )
convert = kwargs.get("convert", None) def _apply(self, val):
def filter(val):
if convert:
val = convert(val)
if val in self.all: if val in self.all:
return val return val
else: else:
...@@ -392,63 +445,72 @@ class EnumStr(ConfigParam): ...@@ -392,63 +445,72 @@ class EnumStr(ConfigParam):
f"Valid options are {self.all}" f"Valid options are {self.all}"
) )
over = kwargs.get("allow_override", True)
super().__init__(default, filter, over)
def __str__(self): def __str__(self):
return f"{self.fullname} ({self.all}) " return f"{self.fullname} ({self.all}) "
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype
def filter(val):
cast_val = mytype(val)
if callable(is_valid):
if is_valid(cast_val):
return cast_val
else:
raise ValueError(
f"Invalid value ({val}) for configuration variable "
f'"{self.fullname}".'
)
return cast_val
super().__init__(default, filter, allow_override=allow_override)
def __str__(self): def __str__(self):
return f"{self.fullname} ({self.mytype}) " # The "_apply" callable is the type itself.
return f"{self.fullname} ({self._apply}) "
class StrParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=str, validate=validate, mutable=mutable)
def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid, allow_override=allow_override)
class IntParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=int, validate=validate, mutable=mutable)
def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid, allow_override=allow_override)
class FloatParam(TypedParam):
def __init__(self, default, validate=None, mutable=True):
super().__init__(default, apply=float, validate=validate, mutable=mutable)
def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid, allow_override=allow_override)
class BoolParam(TypedParam):
"""A boolean parameter that may be initialized from any of the following:
False, 0, "false", "False", "0"
True, 1, "true", "True", "1"
"""
def BoolParam(default, is_valid=None, allow_override=True): def __init__(self, default, validate=None, mutable=True):
# see comment at the beginning of this file. super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
def booltype(s): def _apply(self, value):
if s in ["False", "false", "0", False]: if value in {False, 0, "false", "False", "0"}:
return False return False
elif s in ["True", "true", "1", True]: elif value in {True, 1, "true", "True", "1"}:
return True return True
raise ValueError(
f"Invalid value ({value}) for configuration variable '{self.fullname}'."
)
def is_valid_bool(s):
if s in ["False", "false", "0", "True", "true", "1", False, True]:
return True
else:
return False
if is_valid is None: class DeviceParam(ConfigParam):
is_valid = is_valid_bool def __init__(self, default, *options, **kwargs):
super().__init__(
default, apply=self._apply, mutable=kwargs.get("mutable", True)
)
def _apply(self, val):
if val == self.default or val.startswith("opencl") or val.startswith("cuda"):
return val
elif val.startswith("gpu"):
raise ValueError(
"You are tring to use the old GPU back-end. "
"It was removed from Theano. Use device=cuda* now. "
"See https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29 "
"for more information."
)
else:
raise ValueError(
'Invalid value ("{val}") for configuration '
'variable "{self.fullname}". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
)
return TypedParam(default, booltype, is_valid, allow_override=allow_override) def __str__(self):
return f"{self.fullname} ({self.default}, opencl*, cuda*) "
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论