提交 80a81928 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6150 from ReyhaneAskari/fix_stack_trace

Fix stack trace
...@@ -255,6 +255,21 @@ import theano and print the config variable, as in: ...@@ -255,6 +255,21 @@ import theano and print the config variable, as in:
not predictable, so if you are close to the peak memory usage, trying both not predictable, so if you are close to the peak memory usage, trying both
could give you a small gain. could give you a small gain.
.. attribute:: check_stack_trace
String value, either ``off``, ``log``, ``warn``, ``raise``
Default: ``off``
This is a flag for checking the stack trace during the optimization process.
If :attr:`check_stack_trace` is set to ``off``, no check is performed on the
stack trace. If :attr:`check_stack_trace` is set to ``log`` or ``warn``, a
dummy stack trace is inserted that indicates which optimization inserted the
variable that had an empty stack trace but, in ``warn`` a warning is also
printed.
If :attr:`check_stack_trace` is set to ``raise``, an exception is raised if a
stack trace is missing.
.. attribute:: openmp .. attribute:: openmp
Bool value: either ``True`` or ``False`` Bool value: either ``True`` or ``False``
......
...@@ -226,6 +226,16 @@ optdb.register('add_destroy_handler', AddDestroyHandler(), ...@@ -226,6 +226,16 @@ optdb.register('add_destroy_handler', AddDestroyHandler(),
optdb.register('merge3', gof.MergeOptimizer(), optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run', 'merge') 100, 'fast_run', 'merge')
if theano.config.check_stack_trace in ['raise', 'warn', 'log']:
_tags = ('fast_run', 'fast_compile')
if theano.config.check_stack_trace == 'off':
_tags = ()
optdb.register('CheckStackTrace',
gof.CheckStackTraceOptimization(), -1, *_tags)
del _tags
class Mode(object): class Mode(object):
""" """
......
...@@ -1496,6 +1496,18 @@ AddConfigVar('cycle_detection', ...@@ -1496,6 +1496,18 @@ AddConfigVar('cycle_detection',
EnumStr('regular', 'fast'), EnumStr('regular', 'fast'),
in_c_key=False) in_c_key=False)
AddConfigVar('check_stack_trace',
"A flag for checking the stack trace during the optimization process. "
"default (off): does not check the stack trace of any optimization "
"log: inserts a dummy stack trace that identifies the optimization"
"that inserted the variable that had an empty stack trace."
"warn: prints a warning if a stack trace is missing and also a dummy"
"stack trace is inserted that indicates which optimization inserted"
"the variable that had an empty stack trace."
"raise: raises an exception if a stack trace is missing",
EnumStr('off', 'log', 'warn', 'raise'),
in_c_key=False)
def _timeout_default(): def _timeout_default():
return theano.config.compile.wait * 24 return theano.config.compile.wait * 24
......
...@@ -65,7 +65,7 @@ from theano.gof.opt import ( ...@@ -65,7 +65,7 @@ from theano.gof.opt import (
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub, OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
OpKeyOptimizer) OpKeyOptimizer, CheckStackTraceOptimization)
from theano.gof.optdb import \ from theano.gof.optdb import \
DB, LocalGroupDB, Query, \ DB, LocalGroupDB, Query, \
......
...@@ -3038,3 +3038,34 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -3038,3 +3038,34 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
return False return False
return True return True
class CheckStrackTraceFeature(object):
def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when
# theano.config.check_stack_trace is not off but we also double check here.
if theano.config.check_stack_trace != 'off' and not check_stack_trace(fgraph, 'all'):
if theano.config.check_stack_trace == 'raise':
raise AssertionError(
'Empty stack trace! The optimization that inserted this variable is ' + str(reason))
elif theano.config.check_stack_trace in ['log', 'warn']:
apply_nodes_to_check = fgraph.apply_nodes
for node in apply_nodes_to_check:
for output in node.outputs:
if not hasattr(output.tag, 'trace') or not output.tag.trace:
output.tag.trace = [[('', 0, 'Empty stack trace! The optimization that' +
'inserted this variable is ' + str(reason), '')]]
if theano.config.check_stack_trace == 'warn':
warnings.warn(
'Empty stack trace! The optimization that inserted this variable is' + str(reason))
class CheckStackTraceOptimization(Optimizer):
"""Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
def add_requirements(self, fgraph):
if not hasattr(fgraph, 'CheckStrackTraceFeature'):
fgraph.attach_feature(CheckStrackTraceFeature())
def apply(self, fgraph):
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论