提交 db2e6901 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Yet more pep8 fixes.

上级 9ac04ab4
...@@ -95,7 +95,6 @@ import scalar ...@@ -95,7 +95,6 @@ import scalar
#import sparse #import sparse
import gradient import gradient
from gradient import Rop, Lop, grad from gradient import Rop, Lop, grad
import gof
if config.device.startswith('gpu') or config.init_gpu_device.startswith('gpu'): if config.device.startswith('gpu') or config.init_gpu_device.startswith('gpu'):
import theano.sandbox.cuda import theano.sandbox.cuda
......
import os import os
import logging import logging
import subprocess import subprocess
import sys
from theano.configparser import ( from theano.configparser import (
AddConfigVar, BoolParam, ConfigParam, EnumStr, IntParam, FloatParam, AddConfigVar, BoolParam, ConfigParam, EnumStr, IntParam,
StrParam, TheanoConfigParser) TheanoConfigParser)
_logger = logging.getLogger('theano.configdefaults') _logger = logging.getLogger('theano.configdefaults')
...@@ -41,10 +40,13 @@ AddConfigVar('int_division', ...@@ -41,10 +40,13 @@ AddConfigVar('int_division',
EnumStr('int', 'raise', 'floatX'), EnumStr('int', 'raise', 'floatX'),
in_c_key=False) in_c_key=False)
#gpu mean let the driver select the gpu. Needed in case of gpu in exclusive mode. # gpu means let the driver select the gpu. Needed in case of gpu in
#gpuX mean use the gpu number X. # exclusive mode.
# gpuX mean use the gpu number X.
AddConfigVar('device', AddConfigVar('device',
"Default device for computations. If gpu*, change the default to try to move computation to it and to put shared variable of float32 on it.", ("Default device for computations. If gpu*, change the default to try "
"to move computation to it and to put shared variable of float32 "
"on it."),
EnumStr('cpu', 'gpu', EnumStr('cpu', 'gpu',
'gpu0', 'gpu1', 'gpu2', 'gpu3', 'gpu0', 'gpu1', 'gpu2', 'gpu3',
'gpu4', 'gpu5', 'gpu6', 'gpu7', 'gpu4', 'gpu5', 'gpu6', 'gpu7',
...@@ -93,14 +95,16 @@ try: ...@@ -93,14 +95,16 @@ try:
stdin=dummy_stdin.fileno()) stdin=dummy_stdin.fileno())
# Keep the default linker the same as the one for the mode FAST_RUN # Keep the default linker the same as the one for the mode FAST_RUN
AddConfigVar('linker', AddConfigVar('linker',
"Default linker used if the theano flags mode is Mode or ProfileMode", ("Default linker used if the theano flags mode is Mode "
"or ProfileMode"),
EnumStr('c|py', 'py', 'c', 'c|py_nogc', 'c&py', EnumStr('c|py', 'py', 'c', 'c|py_nogc', 'c&py',
'vm', 'cvm', 'vm_nogc', 'cvm_nogc'), 'vm', 'cvm', 'vm_nogc', 'cvm_nogc'),
in_c_key=False) in_c_key=False)
except OSError: except OSError:
# g++ is not present, linker should default to python only # g++ is not present, linker should default to python only
AddConfigVar('linker', AddConfigVar('linker',
"Default linker used if the theano flags mode is Mode or ProfileMode", ("Default linker used if the theano flags mode is Mode "
"or ProfileMode"),
EnumStr('py', 'c|py', 'c', 'c|py_nogc', 'c&py', EnumStr('py', 'c|py', 'c', 'c|py_nogc', 'c&py',
'vm', 'cvm', 'vm_nogc', 'cvm_nogc'), 'vm', 'cvm', 'vm_nogc', 'cvm_nogc'),
in_c_key=False) in_c_key=False)
...@@ -113,12 +117,14 @@ del dummy_stdin ...@@ -113,12 +117,14 @@ del dummy_stdin
#Keep the default optimizer the same as the one for the mode FAST_RUN #Keep the default optimizer the same as the one for the mode FAST_RUN
AddConfigVar('optimizer', AddConfigVar('optimizer',
"Default optimizer. If not None, will use this linker with the Mode object(not ProfileMode or DebugMode)", ("Default optimizer. If not None, will use this linker with the Mode "
"object (not ProfileMode or DebugMode)"),
EnumStr('fast_run', 'merge', 'fast_compile', 'None'), EnumStr('fast_run', 'merge', 'fast_compile', 'None'),
in_c_key=False) in_c_key=False)
AddConfigVar('on_opt_error', AddConfigVar('on_opt_error',
"What to do when an optimization crashes: warn and skip it, or raise the exception", ("What to do when an optimization crashes: warn and skip it, or raise "
"the exception"),
EnumStr('warn', 'raise'), EnumStr('warn', 'raise'),
in_c_key=False) in_c_key=False)
...@@ -160,16 +166,18 @@ AddConfigVar('nocleanup', ...@@ -160,16 +166,18 @@ AddConfigVar('nocleanup',
# changed at runtime. # changed at runtime.
AddConfigVar('tensor.cmp_sloppy', AddConfigVar('tensor.cmp_sloppy',
"Relax tensor._allclose (0) not at all, (1) a bit, (2) more", "Relax tensor._allclose (0) not at all, (1) a bit, (2) more",
IntParam(0, lambda i: i in (0,1,2), allow_override=False), IntParam(0, lambda i: i in (0, 1, 2), allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('tensor.local_elemwise_fusion', AddConfigVar('tensor.local_elemwise_fusion',
"Enable or not in fast_run mode(fast_run optimization) the elemwise fusion optimization", ("Enable or not in fast_run mode(fast_run optimization) the elemwise "
"fusion optimization"),
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
AddConfigVar('gpu.local_elemwise_fusion', AddConfigVar('gpu.local_elemwise_fusion',
"Enable or not in fast_run mode(fast_run optimization) the gpu elemwise fusion optimization", ("Enable or not in fast_run mode(fast_run optimization) the gpu "
"elemwise fusion optimization"),
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
...@@ -179,7 +187,8 @@ AddConfigVar('lib.amdlibm', ...@@ -179,7 +187,8 @@ AddConfigVar('lib.amdlibm',
BoolParam(False)) BoolParam(False))
AddConfigVar('op.set_flops', AddConfigVar('op.set_flops',
"currently used only in ConvOp. The profile mode will print the flops/s for the op.", ("currently used only in ConvOp. The profile mode will print the "
"flops/s for the op."),
BoolParam(False), BoolParam(False),
in_c_key=False) in_c_key=False)
...@@ -244,8 +253,14 @@ AddConfigVar('numpy.seterr_invalid', ...@@ -244,8 +253,14 @@ AddConfigVar('numpy.seterr_invalid',
### To disable some warning about old bug that are fixed now. ### To disable some warning about old bug that are fixed now.
### ###
AddConfigVar('warn.ignore_bug_before', AddConfigVar('warn.ignore_bug_before',
"If 'None', we warn about all Theano bugs found by default. If 'all', we don't warn about Theano bugs found by default. If a version, we print only the warnings relative to Theano bugs found after that version. Warning for specific bugs can be configured with specific [warn] flags.", ("If 'None', we warn about all Theano bugs found by default. "
EnumStr('None', 'all', '0.3','0.4', '0.4.1', '0.5', allow_override=False), "If 'all', we don't warn about Theano bugs found by default. "
"If a version, we print only the warnings relative to Theano "
"bugs found after that version. "
"Warning for specific bugs can be configured with specific "
"[warn] flags."),
EnumStr('None', 'all', '0.3', '0.4', '0.4.1', '0.5',
allow_override=False),
in_c_key=False) in_c_key=False)
...@@ -263,34 +278,48 @@ def warn_default(version): ...@@ -263,34 +278,48 @@ def warn_default(version):
AddConfigVar('warn.argmax_pushdown_bug', AddConfigVar('warn.argmax_pushdown_bug',
"Warn if in past version of Theano we generated a bug with the theano.tensor.nnet.nnet.local_argmax_pushdown optimization. Was fixed 27 may 2010", ("Warn if in past version of Theano we generated a bug with the "
"theano.tensor.nnet.nnet.local_argmax_pushdown optimization. "
"Was fixed 27 may 2010"),
BoolParam(warn_default('0.3')), BoolParam(warn_default('0.3')),
in_c_key=False) in_c_key=False)
AddConfigVar('warn.gpusum_01_011_0111_bug', AddConfigVar('warn.gpusum_01_011_0111_bug',
"Warn if we are in a case where old version of Theano had a silent bug with GpuSum pattern 01,011 and 0111 when the first dimensions was bigger then 4096. Was fixed 31 may 2010", ("Warn if we are in a case where old version of Theano had a "
"silent bug with GpuSum pattern 01,011 and 0111 when the first "
"dimensions was bigger then 4096. Was fixed 31 may 2010"),
BoolParam(warn_default('0.3')), BoolParam(warn_default('0.3')),
in_c_key=False) in_c_key=False)
AddConfigVar('warn.sum_sum_bug', AddConfigVar('warn.sum_sum_bug',
"Warn if we are in a case where Theano version between version 9923a40c7b7a and the 2 august 2010(fixed date), generated an error in that case. This happen when their is 2 consecutive sum in the graph, bad code was generated. Was fixed 2 August 2010", ("Warn if we are in a case where Theano version between version "
"9923a40c7b7a and the 2 august 2010 (fixed date), generated an "
"error in that case. This happens when there are 2 consecutive "
"sums in the graph, bad code was generated. "
"Was fixed 2 August 2010"),
BoolParam(warn_default('0.3')), BoolParam(warn_default('0.3')),
in_c_key=False) in_c_key=False)
AddConfigVar('warn.sum_div_dimshuffle_bug', AddConfigVar('warn.sum_div_dimshuffle_bug',
"Warn if previous versions of Theano (between rev. 3bd9b789f5e8, 2010-06-16, and cfc6322e5ad4, 2010-08-03) would have given incorrect result. This bug was triggered by sum of division of dimshuffled tensors.", ("Warn if previous versions of Theano (between rev. "
"3bd9b789f5e8, 2010-06-16, and cfc6322e5ad4, 2010-08-03) "
"would have given incorrect result. This bug was triggered by "
"sum of division of dimshuffled tensors."),
BoolParam(warn_default('0.3')), BoolParam(warn_default('0.3')),
in_c_key=False) in_c_key=False)
AddConfigVar('warn.subtensor_merge_bug', AddConfigVar('warn.subtensor_merge_bug',
"Warn if previous versions of Theano (before 0.5rc2) could have given " "Warn if previous versions of Theano (before 0.5rc2) could have given "
"incorrect results when indexing into a subtensor with negative stride " "incorrect results when indexing into a subtensor with negative "
"(for instance, for instance, x[a:b:-1][c]).", "stride (for instance, for instance, x[a:b:-1][c]).",
BoolParam(warn_default('0.5')), BoolParam(warn_default('0.5')),
in_c_key=False) in_c_key=False)
AddConfigVar('compute_test_value', AddConfigVar('compute_test_value',
"If 'True', Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.", ("If 'True', Theano will run each op at graph build time, using "
"Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."),
EnumStr('off', 'ignore', 'warn', 'raise'), EnumStr('off', 'ignore', 'warn', 'raise'),
in_c_key=False) in_c_key=False)
...@@ -310,5 +339,5 @@ AddConfigVar('exception_verbosity', ...@@ -310,5 +339,5 @@ AddConfigVar('exception_verbosity',
A. Elemwise{add_no_inplace} A. Elemwise{add_no_inplace}
B. log_likelihood_v_given_h B. log_likelihood_v_given_h
C. log_likelihood_h""", C. log_likelihood_h""",
EnumStr('low','high'), EnumStr('low', 'high'),
in_c_key=False) in_c_key=False)
# For flag of bool type, we consider the string 'False','false' and '0' as False # For flag of bool type, we consider the strings 'False', 'false' and '0'
# and the string 'True', 'true', '1' as true. # as False, and the string s'True', 'true', '1' as True.
# We also accept the bool type as its corresponding value! # We also accept the bool type as its corresponding value!
import os, StringIO, sys
import ConfigParser
import logging import logging
import os
import sys
import warnings import warnings
import ConfigParser
import StringIO
import theano import theano
_logger = logging.getLogger('theano.configparser') _logger = logging.getLogger('theano.configparser')
class TheanoConfigWarning(Warning): class TheanoConfigWarning(Warning):
def warn(cls, message, stacklevel=0): def warn(cls, message, stacklevel=0):
...@@ -21,16 +24,18 @@ class TheanoConfigWarning(Warning): ...@@ -21,16 +24,18 @@ class TheanoConfigWarning(Warning):
for key in os.environ: for key in os.environ:
if key.startswith("THEANO"): if key.startswith("THEANO"):
if key not in ("THEANO_FLAGS", "THEANORC"): if key not in ("THEANO_FLAGS", "THEANORC"):
TheanoConfigWarning.warn("Ignoring deprecated environment variable %s" % key) TheanoConfigWarning.warn(
"Ignoring deprecated environment variable %s" % key)
THEANO_FLAGS = os.getenv("THEANO_FLAGS", "") THEANO_FLAGS = os.getenv("THEANO_FLAGS", "")
# The THEANO_FLAGS environment variable should be a list of comma-separated # The THEANO_FLAGS environment variable should be a list of comma-separated
# [section.]option=value entries. If the section part is omitted, their should be only one # [section.]option=value entries. If the section part is omitted, there should
# section that contains the given option. # be only one section that contains the given option.
def parse_config_string(config_string, issue_warnings=True): def parse_config_string(config_string, issue_warnings=True):
""" """
Parses a config string composed of comma-separated key=value components into a dict. Parses a config string (comma-separated key=value components) into a dict.
""" """
config_dict = {} config_dict = {}
for kv_pair in THEANO_FLAGS.split(','): for kv_pair in THEANO_FLAGS.split(','):
...@@ -40,7 +45,10 @@ def parse_config_string(config_string, issue_warnings=True): ...@@ -40,7 +45,10 @@ def parse_config_string(config_string, issue_warnings=True):
kv_tuple = kv_pair.split('=', 1) kv_tuple = kv_pair.split('=', 1)
if len(kv_tuple) == 1: if len(kv_tuple) == 1:
if issue_warnings: if issue_warnings:
TheanoConfigWarning.warn("Config key '%s' has no value, ignoring it" % kv_tuple[0], stacklevel=1) TheanoConfigWarning.warn(
("Config key '%s' has no value, ignoring it"
% kv_tuple[0]),
stacklevel=1)
else: else:
k, v = kv_tuple k, v = kv_tuple
# subsequent values for k will override earlier ones # subsequent values for k will override earlier ones
...@@ -49,12 +57,14 @@ def parse_config_string(config_string, issue_warnings=True): ...@@ -49,12 +57,14 @@ def parse_config_string(config_string, issue_warnings=True):
THEANO_FLAGS_DICT = parse_config_string(THEANO_FLAGS, issue_warnings=True) THEANO_FLAGS_DICT = parse_config_string(THEANO_FLAGS, issue_warnings=True)
# THEANORC can contain a colon-delimited list of config files, like # THEANORC can contain a colon-delimited list of config files, like
# THEANORC=~lisa/.theanorc:~/.theanorc # THEANORC=~lisa/.theanorc:~/.theanorc
# In that case, definitions in files on the right (here, ~/.theanorc) have # In that case, definitions in files on the right (here, ~/.theanorc) have
# precedence over those in files on the left. # precedence over those in files on the left.
def config_files_from_theanorc(): def config_files_from_theanorc():
rval = [os.path.expanduser(s) for s in os.getenv('THEANORC', '~/.theanorc').split(os.pathsep)] rval = [os.path.expanduser(s) for s in
os.getenv('THEANORC', '~/.theanorc').split(os.pathsep)]
if os.getenv('THEANORC') is None and sys.platform == "win32": if os.getenv('THEANORC') is None and sys.platform == "win32":
# to don't need to change the filename and make it open easily # to don't need to change the filename and make it open easily
rval.append(os.path.expanduser('~/.theanorc.txt')) rval.append(os.path.expanduser('~/.theanorc.txt'))
...@@ -62,7 +72,9 @@ def config_files_from_theanorc(): ...@@ -62,7 +72,9 @@ def config_files_from_theanorc():
config_files = config_files_from_theanorc() config_files = config_files_from_theanorc()
theano_cfg = ConfigParser.SafeConfigParser({'USER': os.getenv("USER", os.path.split(os.path.expanduser('~'))[-1])}) theano_cfg = ConfigParser.SafeConfigParser(
{'USER': os.getenv("USER", os.path.split(os.path.expanduser('~'))[-1])}
)
theano_cfg.read(config_files) theano_cfg.read(config_files)
# Having a raw version of the config around as well enables us to pass # Having a raw version of the config around as well enables us to pass
# through config values that contain format strings. # through config values that contain format strings.
...@@ -109,6 +121,7 @@ def fetch_val_for_key(key): ...@@ -109,6 +121,7 @@ def fetch_val_for_key(key):
_config_var_list = [] _config_var_list = []
def _config_print(thing, buf): def _config_print(thing, buf):
for cv in _config_var_list: for cv in _config_var_list:
print >> buf, cv print >> buf, cv
...@@ -134,6 +147,7 @@ def get_config_md5(): ...@@ -134,6 +147,7 @@ def get_config_md5():
class TheanoConfigParser(object): class TheanoConfigParser(object):
#properties are installed by AddConfigVar #properties are installed by AddConfigVar
_i_am_a_config_class = True _i_am_a_config_class = True
def __str__(self): def __str__(self):
sio = StringIO.StringIO() sio = StringIO.StringIO()
_config_print(self.__class__, sio) _config_print(self.__class__, sio)
...@@ -142,16 +156,19 @@ class TheanoConfigParser(object): ...@@ -142,16 +156,19 @@ class TheanoConfigParser(object):
# N.B. all instances of TheanoConfigParser give access to the same properties. # N.B. all instances of TheanoConfigParser give access to the same properties.
config = TheanoConfigParser() config = TheanoConfigParser()
#
# The data structure at work here is a tree of CLASSES with CLASS ATTRIBUTES/PROPERTIES that # The data structure at work here is a tree of CLASSES with
# are either a) INSTANTIATED dynamically-generated CLASSES, or b) ConfigParam instances. # CLASS ATTRIBUTES/PROPERTIES that are either a) INSTANTIATED
# The root of this tree is the TheanoConfigParser CLASS, and the internal nodes are the SubObj # dynamically-generated CLASSES, or b) ConfigParam instances. The root
# classes created inside of AddConfigVar(). # of this tree is the TheanoConfigParser CLASS, and the internal nodes
# are the SubObj classes created inside of AddConfigVar().
# Why this design ? # Why this design ?
# - The config object is a true singleton. Every instance of TheanoConfigParser is an empty # - The config object is a true singleton. Every instance of
# instance that looks up attributes/properties in the [single] TheanoConfigParser.__dict__ # TheanoConfigParser is an empty instance that looks up attributes/properties
# in the [single] TheanoConfigParser.__dict__
# - The subtrees provide the same interface as the root # - The subtrees provide the same interface as the root
# - ConfigParser subclasses control get/set of config properties to guard against craziness. # - ConfigParser subclasses control get/set of config properties to guard
# against craziness.
def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
"""Add a new variable to theano.config """Add a new variable to theano.config
...@@ -163,10 +180,12 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -163,10 +180,12 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
:param doc: What does this variable specify? :param doc: What does this variable specify?
:type configparam: ConfigParam instance :type configparam: ConfigParam instance
:param configparam: an object for getting and setting this configuration parameter :param configparam: an object for getting and setting this configuration
parameter
:type root: object :type root: object
:param root: used for recusive calls -- do not provide an argument for this parameter. :param root: used for recusive calls -- do not provide an argument for
this parameter.
:type in_c_key: boolean :type in_c_key: boolean
:param in_c_key: If True, then whenever this config option changes, the :param in_c_key: If True, then whenever this config option changes, the
...@@ -178,7 +197,8 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -178,7 +197,8 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
:returns: None :returns: None
""" """
# this method also performs some of the work of initializing ConfigParam instances # This method also performs some of the work of initializing ConfigParam
# instances
if root is config: if root is config:
#only set the name in the first call, not the recursive ones #only set the name in the first call, not the recursive ones
...@@ -187,21 +207,27 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -187,21 +207,27 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
if len(sections) > 1: if len(sections) > 1:
# set up a subobject # set up a subobject
if not hasattr(root, sections[0]): if not hasattr(root, sections[0]):
# every internal node in the config tree is an instance of its own unique class # every internal node in the config tree is an instance of its own
# unique class
class SubObj(object): class SubObj(object):
_i_am_a_config_class = True _i_am_a_config_class = True
setattr(root.__class__, sections[0], SubObj()) setattr(root.__class__, sections[0], SubObj())
newroot = getattr(root, sections[0]) newroot = getattr(root, sections[0])
if not getattr(newroot, '_i_am_a_config_class', False) or isinstance(newroot, type): if (not getattr(newroot, '_i_am_a_config_class', False)
raise TypeError('Internal config nodes must be config class instances', newroot) or isinstance(newroot, type)):
raise TypeError(
'Internal config nodes must be config class instances',
newroot)
return AddConfigVar('.'.join(sections[1:]), doc, configparam, return AddConfigVar('.'.join(sections[1:]), doc, configparam,
root=newroot, in_c_key=in_c_key) root=newroot, in_c_key=in_c_key)
else: else:
if hasattr(root, name): if hasattr(root, name):
raise AttributeError('This name is already taken', configparam.fullname) raise AttributeError('This name is already taken',
configparam.fullname)
configparam.doc = doc configparam.doc = doc
configparam.in_c_key = in_c_key configparam.in_c_key = in_c_key
configparam.__get__() # trigger a read of the value from config files and env vars # trigger a read of the value from config files and env vars
configparam.__get__()
setattr(root.__class__, sections[0], configparam) setattr(root.__class__, sections[0], configparam)
_config_var_list.append(configparam) _config_var_list.append(configparam)
...@@ -210,8 +236,8 @@ class ConfigParam(object): ...@@ -210,8 +236,8 @@ class ConfigParam(object):
def __init__(self, default, filter=None, allow_override=True): def __init__(self, default, filter=None, allow_override=True):
""" """
If allow_override is False, we can't change the value after the import of Theano. If allow_override is False, we can't change the value after the import
So the value should be the same during all the execution of Theano. So the value should be the same during all the execution.
""" """
self.default = default self.default = default
self.filter = filter self.filter = filter
...@@ -239,7 +265,9 @@ class ConfigParam(object): ...@@ -239,7 +265,9 @@ class ConfigParam(object):
def __set__(self, cls, val): def __set__(self, cls, val):
if not self.allow_override and hasattr(self, 'val'): if not self.allow_override and hasattr(self, 'val'):
raise Exception("Can't change the value of this config parameter after initialization!") raise Exception(
"Can't change the value of this config parameter "
"after initialization!")
#print "SETTING PARAM", self.fullname,(cls), val #print "SETTING PARAM", self.fullname,(cls), val
if self.filter: if self.filter:
self.val = self.filter(val) self.val = self.filter(val)
...@@ -262,42 +290,59 @@ class EnumStr(ConfigParam): ...@@ -262,42 +290,59 @@ class EnumStr(ConfigParam):
if val in self.all: if val in self.all:
return val return val
else: else:
raise ValueError('Invalid value ("%s") for configuration variable "%s". Legal options are %s' raise ValueError((
% (val, self.fullname, self.all)) 'Invalid value ("%s") for configuration variable "%s". '
'Valid options are %s'
% (val, self.fullname, self.all)))
over = kwargs.get("allow_override", True) over = kwargs.get("allow_override", True)
super(EnumStr, self).__init__(default, filter, over) super(EnumStr, self).__init__(default, filter, over)
def __str__(self): def __str__(self):
return '%s (%s) ' % (self.fullname, self.all) return '%s (%s) ' % (self.fullname, self.all)
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None, allow_override=True): def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype self.mytype = mytype
def filter(val): def filter(val):
cast_val = mytype(val) cast_val = mytype(val)
if callable(is_valid): if callable(is_valid):
if is_valid(cast_val): if is_valid(cast_val):
return cast_val return cast_val
else: else:
raise ValueError('Invalid value (%s) for configuration variable "%s".' raise ValueError(
'Invalid value (%s) for configuration variable '
'"%s".'
% (val, self.fullname), val) % (val, self.fullname), val)
return cast_val return cast_val
super(TypedParam, self).__init__(default, filter, allow_override=allow_override)
super(TypedParam, self).__init__(default, filter,
allow_override=allow_override)
def __str__(self): def __str__(self):
return '%s (%s) ' % (self.fullname, self.mytype) return '%s (%s) ' % (self.fullname, self.mytype)
def StrParam(default, is_valid=None, allow_override=True): def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid, allow_override=allow_override) return TypedParam(default, str, is_valid, allow_override=allow_override)
def IntParam(default, is_valid=None, allow_override=True): def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid, allow_override=allow_override) return TypedParam(default, int, is_valid, allow_override=allow_override)
def FloatParam(default, is_valid=None, allow_override=True): def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid, allow_override=allow_override) return TypedParam(default, float, is_valid, allow_override=allow_override)
def BoolParam(default, is_valid=None, allow_override=True): def BoolParam(default, is_valid=None, allow_override=True):
#see comment at the beggining of this file. #see comment at the beginning of this file.
def booltype(s): def booltype(s):
if s in ['False','false','0', False]: if s in ['False', 'false', '0', False]:
return False return False
elif s in ['True','true','1', True]: elif s in ['True', 'true', '1', True]:
return True return True
def is_valid_bool(s): def is_valid_bool(s):
...@@ -305,6 +350,9 @@ def BoolParam(default, is_valid=None, allow_override=True): ...@@ -305,6 +350,9 @@ def BoolParam(default, is_valid=None, allow_override=True):
return True return True
else: else:
return False return False
if is_valid is None: if is_valid is None:
is_valid = is_valid_bool is_valid = is_valid_bool
return TypedParam(default, booltype, is_valid, allow_override=allow_override)
return TypedParam(default, booltype, is_valid,
allow_override=allow_override)
...@@ -12,6 +12,7 @@ from theano.compile.sharedvalue import SharedVariable ...@@ -12,6 +12,7 @@ from theano.compile.sharedvalue import SharedVariable
import logging import logging
logger = logging.getLogger('theano.updates') logger = logging.getLogger('theano.updates')
class Updates(dict): class Updates(dict):
""" """
Dict-like mapping from SharedVariable keys to their new values. Dict-like mapping from SharedVariable keys to their new values.
...@@ -30,7 +31,9 @@ class Updates(dict): ...@@ -30,7 +31,9 @@ class Updates(dict):
return super(Updates, self).__setitem__(key, value) return super(Updates, self).__setitem__(key, value)
else: else:
raise TypeError('Updates keys must inherit from SharedVariable', key) raise TypeError('Updates keys must inherit from SharedVariable',
key)
def update(self, other): def update(self, other):
for key, val in dict(other).iteritems(): for key, val in dict(other).iteritems():
if key in self: if key in self:
...@@ -50,4 +53,3 @@ class Updates(dict): ...@@ -50,4 +53,3 @@ class Updates(dict):
rval.update(other) rval.update(other)
rval.update(self) rval.update(self)
return rval return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论