提交 d7963c11 authored 作者: lamblin's avatar lamblin

Merge pull request #1428 from nouiz/err_msg

Better error message when an apply node execution fail.
...@@ -1908,7 +1908,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1908,7 +1908,7 @@ class _Linker(gof.link.LocalLinker):
try: try:
thunk_c() thunk_c()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk_c)
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
...@@ -1958,7 +1958,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1958,7 +1958,7 @@ class _Linker(gof.link.LocalLinker):
try: try:
thunk_c() thunk_c()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk_c)
_logger.debug( _logger.debug(
'%i - calling _check_preallocated_output ' '%i - calling _check_preallocated_output '
'with thunk_c', i) 'with thunk_c', i)
......
...@@ -580,10 +580,18 @@ class Function(object): ...@@ -580,10 +580,18 @@ class Function(object):
outputs = self.fn() outputs = self.fn()
except Exception: except Exception:
if hasattr(self.fn, 'position_of_error'): if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function # this is a new vm-provided function or c linker
# the C VM needs this because the exception manipulation # they need this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error]) if hasattr(self.fn, 'thunks'):
# For the CVM
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the temps values
# So for now, we just don't print the extra shapes/strides info
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
......
...@@ -933,6 +933,7 @@ class CLinker(link.Linker): ...@@ -933,6 +933,7 @@ class CLinker(link.Linker):
keep_lock=keep_lock) keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage) res = _CThunk(cthunk, init_tasks, tasks, error_storage)
res.nodes = self.node_order
return res, in_storage, out_storage return res, in_storage, out_storage
def cmodule_key(self): def cmodule_key(self):
...@@ -1391,11 +1392,9 @@ class _CThunk(object): ...@@ -1391,11 +1392,9 @@ class _CThunk(object):
trace = () trace = ()
try: try:
exc_type, _exc_value, exc_trace = self.error_storage exc_type, _exc_value, exc_trace = self.error_storage
if hasattr(task, "outputs"): self.position_of_error = self.nodes.index(task)
exc_value = exc_type(_exc_value, task, task.outputs)
else:
exc_value = exc_type(_exc_value, task)
# this can be used to retrieve the location the Op was declared # this can be used to retrieve the location the Op was declared
exc_value = exc_type(_exc_value)
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
except Exception: except Exception:
print >> sys.stderr, ('ERROR retrieving error_storage.' print >> sys.stderr, ('ERROR retrieving error_storage.'
......
"""WRITEME""" """WRITEME"""
from copy import copy from copy import copy
import StringIO
import sys import sys
import traceback import traceback
...@@ -55,15 +56,15 @@ sys.excepthook = thunk_hook ...@@ -55,15 +56,15 @@ sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule # TODO: Make this work with linker defined schedule
def raise_with_op(op, exc_info=None): def raise_with_op(op, thunk=None, exc_info=None):
""" """
Re-raise an exception while annotating the exception object with Re-raise an exception while annotating the exception object with
debug info. debug info.
Parameters Parameters
---------- ----------
op : object op : Apply node
The Op object that resulted in the raised exception. The Apply node object that resulted in the raised exception.
exc_info : tuple, optional exc_info : tuple, optional
A tuple containing the exception type, exception object and A tuple containing the exception type, exception object and
associated traceback, as would be returned by a call to associated traceback, as would be returned by a call to
...@@ -95,7 +96,10 @@ def raise_with_op(op, exc_info=None): ...@@ -95,7 +96,10 @@ def raise_with_op(op, exc_info=None):
try: try:
trace = op.tag.trace trace = op.tag.trace
except AttributeError: except AttributeError:
trace = () try:
trace = op.op.tag.trace
except AttributeError:
trace = ()
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = op exc_value.__op_instance__ = op
if op in op.fgraph.toposort(): if op in op.fgraph.toposort():
...@@ -108,6 +112,29 @@ def raise_with_op(op, exc_info=None): ...@@ -108,6 +112,29 @@ def raise_with_op(op, exc_info=None):
if raise_with_op.print_thunk_trace: if raise_with_op.print_thunk_trace:
log_thunk_trace(exc_value) log_thunk_trace(exc_value)
if theano.config.exception_verbosity == 'high':
f = StringIO.StringIO()
theano.printing.debugprint(op, file=f, stop_on_name=True)
if thunk is not None:
shapes = [getattr(ipt[0], 'shape', 'No shapes')
for ipt in thunk.inputs]
strides = [getattr(ipt[0], 'strides', 'No strides')
for ipt in thunk.inputs]
detailed_err_msg = ("\nInputs shapes: %s \n" % shapes +
"Inputs strides: %s \n" % strides +
"Debugprint of the apply node: \n" +
f.getvalue())
else:
detailed_err_msg = "\nDebugprint of the apply node: \n" + f.getvalue()
else:
detailed_err_msg = ("\nUse the Theano flag"
" 'exception_verbosity=high' for more"
" information on the inputs of this apply"
" node.")
exc_value = exc_type(str(exc_value) +
"\nApply node that caused the error: " + str(op) +
detailed_err_msg)
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
raise_with_op.print_thunk_trace = False raise_with_op.print_thunk_trace = False
...@@ -360,7 +387,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -360,7 +387,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f = streamline_default_f f = streamline_default_f
elif nice_errors: elif nice_errors:
thunk_node_list = zip(thunks, order) thunk_node_list = zip(thunks, order)
...@@ -372,7 +399,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -372,7 +399,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
for thunk, node in thunk_node_list: for thunk, node in thunk_node_list:
thunk() thunk()
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f = streamline_nice_errors_f f = streamline_nice_errors_f
else: else:
# don't worry about raise_with_op, just go a little faster. # don't worry about raise_with_op, just go a little faster.
...@@ -666,7 +693,7 @@ class WrapLinker(Linker): ...@@ -666,7 +693,7 @@ class WrapLinker(Linker):
try: try:
wrapper(i, node, *thunks) wrapper(i, node, *thunks)
except Exception: except Exception:
raise_with_op(node) raise_with_op(node, thunk)
f.thunk_groups = thunk_groups f.thunk_groups = thunk_groups
return f, inputs0, outputs0 return f, inputs0, outputs0
......
...@@ -165,7 +165,7 @@ class Loop(VM): ...@@ -165,7 +165,7 @@ class Loop(VM):
self.call_counts[i] += 1 self.call_counts[i] += 1
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
except: except:
raise_with_op(node) raise_with_op(node, thunk)
else: else:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
...@@ -173,7 +173,7 @@ class Loop(VM): ...@@ -173,7 +173,7 @@ class Loop(VM):
for thunk, node in zip(self.thunks, self.nodes): for thunk, node in zip(self.thunks, self.nodes):
thunk() thunk()
except: except:
raise_with_op(node) raise_with_op(node, thunk)
class LoopGC(VM): class LoopGC(VM):
...@@ -205,7 +205,7 @@ class LoopGC(VM): ...@@ -205,7 +205,7 @@ class LoopGC(VM):
old_s[0] = None old_s[0] = None
i += 1 i += 1
except: except:
raise_with_op(node) raise_with_op(node, thunk)
else: else:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
...@@ -216,7 +216,7 @@ class LoopGC(VM): ...@@ -216,7 +216,7 @@ class LoopGC(VM):
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
except: except:
raise_with_op(node) raise_with_op(node, thunk)
class Stack(VM): class Stack(VM):
...@@ -400,7 +400,8 @@ class Stack(VM): ...@@ -400,7 +400,8 @@ class Stack(VM):
st = 'c' st = 'c'
self.variable_strides[var] = st self.variable_strides[var] = st
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply,
self.thunks[self.node_idx[current_apply]])
for o in current_apply.outputs: for o in current_apply.outputs:
compute_map[o][0] = 1 compute_map[o][0] = 1
if self.allow_gc: if self.allow_gc:
...@@ -458,7 +459,8 @@ class Stack(VM): ...@@ -458,7 +459,8 @@ class Stack(VM):
self.call_times[current_idx] += dt self.call_times[current_idx] += dt
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply,
self.thunks[self.node_idx[current_apply]])
if requires: if requires:
for r in requires: for r in requires:
......
"""Symbolic Op for raising an exception.""" """Symbolic Op for raising an exception."""
__authors__ = "James Bergstra" __authors__ = "James Bergstra"
__copyright__ = "(c) 2011, Universite de Montreal" __copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License" __license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>" __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
from theano import gof from theano import gof
class Raise(gof.Op): class Raise(gof.Op):
"""Op whose perform() raises an exception. """Op whose perform() raises an exception.
""" """
...@@ -18,19 +19,22 @@ class Raise(gof.Op): ...@@ -18,19 +19,22 @@ class Raise(gof.Op):
""" """
self.msg = msg self.msg = msg
self.exc = exc self.exc = exc
def __eq__(self, other): def __eq__(self, other):
# Note: the msg does not technically have to be in the hash and eq # Note: the msg does not technically have to be in the hash and eq
# because it doesn't affect the return value. # because it doesn't affect the return value.
return (type(self) == type(other) return (type(self) == type(other)
and self.msg == other.msg and self.msg == other.msg
and self.exc == other.exc) and self.exc == other.exc)
def __hash__(self): def __hash__(self):
return hash((type(self), self.msg, self.exc)) return hash((type(self), self.msg, self.exc))
def __str__(self): def __str__(self):
return "Raise{%s(%s)}"%(self.exc, self.msg) return "Raise{%s(%s)}" % (self.exc, self.msg)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
raise self.exc(self.msg) raise self.exc(self.msg)
...@@ -974,10 +974,18 @@ class Scan(PureOp): ...@@ -974,10 +974,18 @@ class Scan(PureOp):
fn() fn()
except Exception: except Exception:
if hasattr(fn, 'position_of_error'): if hasattr(fn, 'position_of_error'):
# this is a new vm-provided function # this is a new vm-provided function or c linker
# the C VM needs this because the exception manipulation # they need this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
gof.vm.raise_with_op(fn.nodes[fn.position_of_error]) if hasattr(self.fn, 'thunks'):
# For the CVM
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the temps values
# So for now, we just don't print the extra shapes/strides info
gof.vm.raise_with_op(self.fn.nodes[self.fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
......
...@@ -7862,24 +7862,11 @@ class Dot(Op): ...@@ -7862,24 +7862,11 @@ class Dot(Op):
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y = inp x, y = inp
z, = out z, = out
try:
# the asarray is here because dot between two vectors # the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d # gives a numpy float object but we need to return a 0d
# ndarray # ndarray
z[0] = numpy.asarray(numpy.dot(x, y)) z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to
# add that
if config.exception_verbosity == 'high':
raise ValueError('dot product failed.\n'
'First arg dims: ' + str(x.shape) + '\n'
'Second arg dims: ' + str(y.shape) + '\n'
'First arg: \n' +
min_informative_str(node.inputs[0]) +
'\nSecond arg: \n' +
min_informative_str(node.inputs[1]))
e.args = e.args + (x.shape, y.shape)
raise
def grad(self, inp, grads): def grad(self, inp, grads):
......
...@@ -806,14 +806,7 @@ class Elemwise(Op): ...@@ -806,14 +806,7 @@ class Elemwise(Op):
base_exc_str = 'Dimension mismatch; shapes are %s' % ( base_exc_str = 'Dimension mismatch; shapes are %s' % (
', '.join(msg)) ', '.join(msg))
if config.exception_verbosity == 'high': raise ValueError(base_exc_str)
msg_chunks = [base_exc_str]
for i, ipt in enumerate(node.inputs):
msg_chunks.append('input %d: %s' %
(i, min_informative_str(ipt)))
raise ValueError('\n'.join(msg_chunks))
else:
raise ValueError(base_exc_str)
# Determine the shape of outputs # Determine the shape of outputs
out_shape = [] out_shape = []
...@@ -874,29 +867,7 @@ class Elemwise(Op): ...@@ -874,29 +867,7 @@ class Elemwise(Op):
self.scalar_op.nout)) self.scalar_op.nout))
nout = ufunc.nout nout = ufunc.nout
try: variables = ufunc(*ufunc_args)
variables = ufunc(*ufunc_args)
except Exception, e:
errormsg = ('While computing ' + str(node.outputs) +
': Failed calling ufunc for op ' +
str(self.scalar_op) +
' for params of shape ' +
str([arg.shape for arg in ufunc_args]))
if config.exception_verbosity == 'high':
errormsg += 'inputs are: \n'
for i, ipt in enumerate(node.inputs):
errormsg += '(' + str(i) + ') ' + \
min_informative_str(ipt) + '\n'
errormsg += 'outputs are: \n'
for i, output in enumerate(node.outputs):
errormsg += '(' + str(i) + ') ' + \
min_informative_str(output) + '\n'
errormsg += 'original exception was: ' + '\n'.join(
traceback.format_exception_only(*sys.exc_info()[0:2]))
e.args = e.args + (errormsg, )
raise
if nout == 1: if nout == 1:
variables = [variables] variables = [variables]
......
...@@ -2898,7 +2898,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2898,7 +2898,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.eval_output_and_check(t) self.eval_output_and_check(t)
assert 0 assert 0
except Exception, e: except Exception, e:
if exc_message(e) != 'index out of bounds': if 'out of bounds' not in exc_message(e):
raise raise
finally: finally:
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
...@@ -4907,9 +4907,6 @@ class t_dot(unittest.TestCase): ...@@ -4907,9 +4907,6 @@ class t_dot(unittest.TestCase):
# Reported by Theano perform # Reported by Theano perform
e0.split()[0:4] e0.split()[0:4]
== ['Incompatible', 'shapes', 'for', 'gemv'] or == ['Incompatible', 'shapes', 'for', 'gemv'] or
# Reported by Theano when 'exception_verbosity' is set
# to 'high'.
e0.split()[0:3] == ['dot', 'product', 'failed.'],
e) e)
finally: finally:
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论