提交 71f727da authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Generalize excepthook registration

上级 5bb459cf
import sys from theano.link.basic import (
from theano.link.basic import (
Container, Container,
Linker, Linker,
LocalLinker, LocalLinker,
...@@ -11,7 +9,7 @@ from theano.link.basic import ( ...@@ -11,7 +9,7 @@ from theano.link.basic import (
map_storage, map_storage,
streamline, streamline,
) )
from theano.link.debugging import raise_with_op, set_excepthook from theano.link.debugging import raise_with_op, register_thunk_trace_excepthook
set_excepthook(handler=sys.stdout) register_thunk_trace_excepthook()
...@@ -6,74 +6,10 @@ from operator import itemgetter ...@@ -6,74 +6,10 @@ from operator import itemgetter
import numpy as np import numpy as np
from theano import config from theano import config, utils
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
def __log_thunk_trace(value, handler):
"""
Log Theano's diagnostic stack trace for an exception
raised by raise_with_op.
"""
def write(msg):
print(f"log_thunk_trace: {msg.strip()}", file=handler)
if hasattr(value, "__thunk_trace__"):
trace2 = value.__thunk_trace__
write("There was a problem executing an Op.")
if trace2 is None:
write("Could not find where this Op was defined.")
write(
" * You might have instantiated this Op "
"directly instead of using a constructor."
)
write(
" * The Op you constructed might have been"
" optimized. Try turning off optimizations."
)
elif trace2:
write("Definition in: ")
for line in traceback.format_list(trace2):
write(line)
write(
"For the full definition stack trace set"
" the Theano flags traceback__limit to -1"
)
def set_excepthook(handler: io.TextIOWrapper):
def thunk_hook(type, value, trace):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
-----
This hook replaced in testing, so it does not run.
"""
__log_thunk_trace(value, handler=handler)
sys.__excepthook__(type, value, trace)
sys.excepthook = thunk_hook
def raise_with_op( def raise_with_op(
fgraph: FunctionGraph, node, thunk=None, exc_info=None, storage_map=None fgraph: FunctionGraph, node, thunk=None, exc_info=None, storage_map=None
): ):
...@@ -334,3 +270,54 @@ def raise_with_op( ...@@ -334,3 +270,54 @@ def raise_with_op(
# Some exception need extra parameter in inputs. So forget the # Some exception need extra parameter in inputs. So forget the
# extra long error message in that case. # extra long error message in that case.
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
def __log_thunk_trace(value, handler: io.TextIOWrapper):
"""
Log Theano's diagnostic stack trace for an exception.
Uses custom attributes that are added to trace objects by raise_with_op.
"""
def write(msg):
print(f"log_thunk_trace: {msg.strip()}", file=handler)
if hasattr(value, "__thunk_trace__"):
trace2 = value.__thunk_trace__
write("There was a problem executing an Op.")
if trace2 is None:
write("Could not find where this Op was defined.")
write(
" * You might have instantiated this Op "
"directly instead of using a constructor."
)
write(
" * The Op you constructed might have been"
" optimized. Try turning off optimizations."
)
elif trace2:
write("Definition in: ")
for line in traceback.format_list(trace2):
write(line)
write(
"For the full definition stack trace set"
" the Theano flags traceback__limit to -1"
)
def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
"""Adds the __log_thunk_trace except hook to the collection in theano.utils.
Parameters
----------
handler : TextIOWrapper
Target for printing the output.
"""
def wrapper(type, value, trace):
__log_thunk_trace(value, handler)
utils.add_excepthook(wrapper)
register_thunk_trace_excepthook()
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import inspect import inspect
import os import os
import subprocess import subprocess
import sys
import traceback import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
...@@ -23,6 +24,49 @@ __all__ = [ ...@@ -23,6 +24,49 @@ __all__ = [
] ]
__excepthooks = []
def __call_excepthooks(type, value, trace):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
-----
This hook replaced in testing, so it does not run.
"""
for hook in __excepthooks:
hook(type, value, trace)
sys.__excepthook__(type, value, trace)
def add_excepthook(hook):
"""Adds an excepthook to a list of excepthooks that are called
when an unhandled exception happens.
See https://docs.python.org/3/library/sys.html#sys.excepthook for signature info.
"""
__excepthooks.append(hook)
sys.excepthook = __call_excepthooks
def exc_message(e): def exc_message(e):
""" """
In python 3.x, when an exception is reraised it saves original In python 3.x, when an exception is reraised it saves original
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论