提交 d1d31a31 authored 作者: Frederic Bastien's avatar Frederic Bastien

symplify the managing of mode.

上级 5b8a6be8
...@@ -681,8 +681,7 @@ class FunctionMaker(object): ...@@ -681,8 +681,7 @@ class FunctionMaker(object):
in the graph from the inputs to the outputs in the graph from the inputs to the outputs
""" """
if mode is None: mode = mode_module.get_mode(mode)
mode = mode_module.default_mode
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False unpack_single = False
...@@ -706,8 +705,7 @@ class FunctionMaker(object): ...@@ -706,8 +705,7 @@ class FunctionMaker(object):
env, additional_outputs = std_env(expanded_inputs, outputs, accept_inplace) env, additional_outputs = std_env(expanded_inputs, outputs, accept_inplace)
self.env = env self.env = env
# Fetch the mode and then the optimizer and linker # Fetch the optimizer and linker
mode = mode_module.predefined_modes.get(mode, mode)
optimizer, linker = mode.optimizer, copy.copy(mode.linker) optimizer, linker = mode.optimizer, copy.copy(mode.linker)
# optimize the env # optimize the env
...@@ -870,8 +868,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -870,8 +868,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
#`Out` instance if necessary: #`Out` instance if necessary:
t1 = time.time() t1 = time.time()
if mode is None: mode = mode_module.get_mode(mode)
mode = mode_module.default_mode
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
...@@ -882,7 +879,6 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -882,7 +879,6 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
defaults = [getattr(input, 'value', None) for input in inputs] defaults = [getattr(input, 'value', None) for input in inputs]
mode = mode_module.predefined_modes.get(mode, mode)
if isinstance(mode, (list, tuple)): # "mode comparison" semantics if isinstance(mode, (list, tuple)): # "mode comparison" semantics
_logger.warning('Passing multiple modes is deprecated (20091019)') _logger.warning('Passing multiple modes is deprecated (20091019)')
if not mode: if not mode:
......
...@@ -214,10 +214,15 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE, ...@@ -214,10 +214,15 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
## ##
default_mode = config.THEANO_DEFAULT_MODE default_mode = config.THEANO_DEFAULT_MODE
def get_mode(string):
if string is None: string = default_mode
if not isinstance(string, str): return string #it is already a mode...
if not predefined_modes.has_key(string):
raise Exception("No predefixed mode exist for string: %s"%string)
return predefined_modes[string]
def get_default_mode(): def get_default_mode():
if not predefined_modes.has_key(default_mode): return get_mode(default_mode)
raise Exception("No predefixed mode exist for string: %s"%default_mode)
return predefined_modes[default_mode]
def register_mode(name, mode): def register_mode(name, mode):
"""Add a `Mode` which can be referred to by `name` in `function`.""" """Add a `Mode` which can be referred to by `name` in `function`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论