提交 fdbe417b authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Make flag mode=NanGuardMode work and make it user provided optimizer

上级 5922a930
......@@ -387,12 +387,17 @@ def get_mode(orig_string):
default_mode_class):
return instanciated_default_mode
if string in ['Mode', 'ProfileMode', 'DebugMode']:
if string in ['Mode', 'ProfileMode', 'DebugMode', 'NanGuardMode']:
if string == 'DebugMode':
# need to import later to break circular dependency.
from .debugmode import DebugMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == 'NanGuardMode':
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode
# DebugMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
# This might be required if the string is 'ProfileMode'
from .profilemode import ProfileMode # noqa
......
......@@ -110,9 +110,15 @@ class NanGuardMode(Mode):
big_is_error : bool
If True, raise an error when a value greater than 1e10 is encountered.
Note
----
We ignore the linker parameter
"""
def __init__(self, nan_is_error, inf_is_error, big_is_error=True):
# We currently loose the 3 first param freuquently, when calling
# mode.including() and variant.
def __init__(self, nan_is_error=True, inf_is_error=True, big_is_error=True,
optimizer=None, linker=None):
self.provided_optimizer = optimizer
cuda_compile_failed = False
if cuda.cuda_available:
self.guard_input = cuda.fvector('nan_guard')
......@@ -246,4 +252,4 @@ class NanGuardMode(Mode):
wrap_linker = theano.gof.WrapLinker([theano.gof.OpWiseCLinker()],
nan_check)
super(NanGuardMode, self).__init__(wrap_linker,
optimizer=theano.config.optimizer)
optimizer=self.provided_optimizer)
......@@ -150,6 +150,7 @@ AddConfigVar(
'mode',
"Default compilation mode",
EnumStr('Mode', 'ProfileMode', 'DebugMode', 'FAST_RUN',
'NanGuardMode',
'FAST_COMPILE', 'PROFILE_MODE', 'DEBUG_MODE'),
in_c_key=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论