提交 8cf75dc8 authored 作者: Frederic's avatar Frederic

allow floatX={32,64} in the user config code.

上级 6aa85f9e
...@@ -11,10 +11,18 @@ _logger = logging.getLogger('theano.configdefaults') ...@@ -11,10 +11,18 @@ _logger = logging.getLogger('theano.configdefaults')
config = TheanoConfigParser() config = TheanoConfigParser()
def floatX_convert(s):
if s == "32":
return "float32"
elif s == "64":
return "float64"
else:
return s
AddConfigVar('floatX', AddConfigVar('floatX',
"Default floating-point precision for python casts", "Default floating-point precision for python casts",
EnumStr('float64', 'float32'), EnumStr('float64', 'float32', convert=floatX_convert,),
) )
AddConfigVar('cast_policy', AddConfigVar('cast_policy',
"Rules for implicit type casting", "Rules for implicit type casting",
......
...@@ -305,7 +305,11 @@ class EnumStr(ConfigParam): ...@@ -305,7 +305,11 @@ class EnumStr(ConfigParam):
raise ValueError('Valid values for an EnumStr parameter ' raise ValueError('Valid values for an EnumStr parameter '
'should be strings', val, type(val)) 'should be strings', val, type(val))
convert = kwargs.get("convert", None)
def filter(val): def filter(val):
if convert:
val = convert(val)
if val in self.all: if val in self.all:
return val return val
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论