提交 c3ddf8d6 authored 作者: Frederic's avatar Frederic

In cvm, vm, py, c|py and debugmode, print more info when the execution of an apply node crash.

We print the inputs shape and strides now. We can't do that with the c linker.
上级 d8b8a183
...@@ -1908,7 +1908,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1908,7 +1908,7 @@ class _Linker(gof.link.LocalLinker):
try: try:
thunk_c() thunk_c()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk_c)
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
...@@ -1958,7 +1958,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1958,7 +1958,7 @@ class _Linker(gof.link.LocalLinker):
try: try:
thunk_c() thunk_c()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk_c)
_logger.debug( _logger.debug(
'%i - calling _check_preallocated_output ' '%i - calling _check_preallocated_output '
'with thunk_c', i) 'with thunk_c', i)
......
...@@ -583,6 +583,14 @@ class Function(object): ...@@ -583,6 +583,14 @@ class Function(object):
# this is a new vm-provided function # this is a new vm-provided function
# the C VM needs this because the exception manipulation # the C VM needs this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
if hasattr(self.fn, 'thunks'):
# For the CVM
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the temps values
# So for now, we just don't print the extra shapes/strides info
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error]) gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
......
"""WRITEME""" """WRITEME"""
from copy import copy from copy import copy
import StringIO
import sys import sys
import traceback import traceback
...@@ -55,7 +56,7 @@ sys.excepthook = thunk_hook ...@@ -55,7 +56,7 @@ sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule # TODO: Make this work with linker defined schedule
def raise_with_op(op, exc_info=None): def raise_with_op(op, thunk=None, exc_info=None):
""" """
Re-raise an exception while annotating the exception object with Re-raise an exception while annotating the exception object with
debug info. debug info.
...@@ -94,6 +95,9 @@ def raise_with_op(op, exc_info=None): ...@@ -94,6 +95,9 @@ def raise_with_op(op, exc_info=None):
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
try: try:
trace = op.tag.trace trace = op.tag.trace
except AttributeError:
try:
trace = op.op.tag.trace
except AttributeError: except AttributeError:
trace = () trace = ()
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
...@@ -108,10 +112,29 @@ def raise_with_op(op, exc_info=None): ...@@ -108,10 +112,29 @@ def raise_with_op(op, exc_info=None):
if raise_with_op.print_thunk_trace: if raise_with_op.print_thunk_trace:
log_thunk_trace(exc_value) log_thunk_trace(exc_value)
if hasattr(op, "outputs"): if theano.config.exception_verbosity == 'high':
exc_value = exc_type(exc_value, op, op.outputs) f = StringIO.StringIO()
theano.printing.debugprint(op, file=f, stop_on_name=True)
if thunk is not None:
shapes = [getattr(ipt[0], 'shape', 'No shapes')
for ipt in thunk.inputs]
strides = [getattr(ipt[0], 'strides', 'No strides')
for ipt in thunk.inputs]
detailed_err_msg = ("\nInputs shapes: %s \n" % shapes +
"Inputs strides: %s \n" % strides +
"Debugprint of the apply node: \n" +
f.getvalue())
else:
detailed_err_msg = "\nDebugprint of the apply node: \n" + f.getvalue()
else: else:
exc_value = exc_type(exc_value, op) detailed_err_msg = ("\nUse the Theano flag"
" 'exception_verbosity=high' for more"
" information on the inputs of this apply"
" node.")
exc_value = exc_type(str(exc_value) +
"\nApply node that caused the error: " + str(op) +
detailed_err_msg)
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
raise_with_op.print_thunk_trace = False raise_with_op.print_thunk_trace = False
...@@ -364,7 +387,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -364,7 +387,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f = streamline_default_f f = streamline_default_f
elif nice_errors: elif nice_errors:
thunk_node_list = zip(thunks, order) thunk_node_list = zip(thunks, order)
...@@ -376,7 +399,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -376,7 +399,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
for thunk, node in thunk_node_list: for thunk, node in thunk_node_list:
thunk() thunk()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f = streamline_nice_errors_f f = streamline_nice_errors_f
else: else:
# don't worry about raise_with_op, just go a little faster. # don't worry about raise_with_op, just go a little faster.
...@@ -670,7 +693,7 @@ class WrapLinker(Linker): ...@@ -670,7 +693,7 @@ class WrapLinker(Linker):
try: try:
wrapper(i, node, *thunks) wrapper(i, node, *thunks)
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f.thunk_groups = thunk_groups f.thunk_groups = thunk_groups
return f, inputs0, outputs0 return f, inputs0, outputs0
......
...@@ -165,7 +165,7 @@ class Loop(VM): ...@@ -165,7 +165,7 @@ class Loop(VM):
self.call_counts[i] += 1 self.call_counts[i] += 1
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
except: except:
raise_with_op(node) raise_with_op(node, thunk)
else: else:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
...@@ -173,7 +173,7 @@ class Loop(VM): ...@@ -173,7 +173,7 @@ class Loop(VM):
for thunk, node in zip(self.thunks, self.nodes): for thunk, node in zip(self.thunks, self.nodes):
thunk() thunk()
except: except:
raise_with_op(node) raise_with_op(node, thunk)
class LoopGC(VM): class LoopGC(VM):
...@@ -205,7 +205,7 @@ class LoopGC(VM): ...@@ -205,7 +205,7 @@ class LoopGC(VM):
old_s[0] = None old_s[0] = None
i += 1 i += 1
except: except:
raise_with_op(node) raise_with_op(node, thunk)
else: else:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
...@@ -216,7 +216,7 @@ class LoopGC(VM): ...@@ -216,7 +216,7 @@ class LoopGC(VM):
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
except: except:
raise_with_op(node) raise_with_op(node, thunk)
class Stack(VM): class Stack(VM):
...@@ -400,7 +400,8 @@ class Stack(VM): ...@@ -400,7 +400,8 @@ class Stack(VM):
st = 'c' st = 'c'
self.variable_strides[var] = st self.variable_strides[var] = st
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply,
self.thunks[self.node_idx[current_apply]])
for o in current_apply.outputs: for o in current_apply.outputs:
compute_map[o][0] = 1 compute_map[o][0] = 1
if self.allow_gc: if self.allow_gc:
...@@ -458,7 +459,8 @@ class Stack(VM): ...@@ -458,7 +459,8 @@ class Stack(VM):
self.call_times[current_idx] += dt self.call_times[current_idx] += dt
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply,
self.thunks[self.node_idx[current_apply]])
if requires: if requires:
for r in requires: for r in requires:
......
...@@ -974,10 +974,18 @@ class Scan(PureOp): ...@@ -974,10 +974,18 @@ class Scan(PureOp):
fn() fn()
except Exception: except Exception:
if hasattr(fn, 'position_of_error'): if hasattr(fn, 'position_of_error'):
# this is a new vm-provided function # this is a new vm-provided function or c linker
# the C VM needs this because the exception manipulation # they need this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
gof.vm.raise_with_op(fn.nodes[fn.position_of_error]) if hasattr(self.fn, 'thunks'):
# For the CVM
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the temps values
# So for now, we just don't print the extra shapes/strides info
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论