提交 25f9cdf4 authored 作者: Frederic's avatar Frederic

Make the inputs shape and strides printed by default when there is an error.

Also print the inputs type. exception_verbosity now only add the debugprint that now also have the print_type parameter
上级 e27ddec5
...@@ -112,29 +112,32 @@ def raise_with_op(op, thunk=None, exc_info=None): ...@@ -112,29 +112,32 @@ def raise_with_op(op, thunk=None, exc_info=None):
if raise_with_op.print_thunk_trace: if raise_with_op.print_thunk_trace:
log_thunk_trace(exc_value) log_thunk_trace(exc_value)
if theano.config.exception_verbosity == 'high': detailed_err_msg = "\nApply node that caused the error: " + str(op)
f = StringIO.StringIO()
theano.printing.debugprint(op, file=f, stop_on_name=True)
if thunk is not None: if thunk is not None:
shapes = [getattr(ipt[0], 'shape', 'No shapes') shapes = [getattr(ipt[0], 'shape', 'No shapes')
for ipt in thunk.inputs] for ipt in thunk.inputs]
strides = [getattr(ipt[0], 'strides', 'No strides') strides = [getattr(ipt[0], 'strides', 'No strides')
for ipt in thunk.inputs] for ipt in thunk.inputs]
detailed_err_msg = ("\nInputs shapes: %s \n" % shapes + types = [getattr(ipt[0], 'type', 'No type')
"Inputs strides: %s \n" % strides + for ipt in op.inputs]
"Debugprint of the apply node: \n" + detailed_err_msg += ("\nInputs shapes: %s" % shapes +
f.getvalue()) "\nInputs strides: %s" % strides +
"\nInputs types: %s" % types)
else: else:
detailed_err_msg = "\nDebugprint of the apply node: \n" + f.getvalue() detailed_err_msg += ("\nUse another linker then the c linker to"
" have the inputs shapes and strides printed.")
if theano.config.exception_verbosity == 'high':
f = StringIO.StringIO()
theano.printing.debugprint(op, file=f, stop_on_name=True,
print_type=True)
detailed_err_msg += "\nDebugprint of the apply node: \n" + f.getvalue()
else: else:
detailed_err_msg = ("\nUse the Theano flag" detailed_err_msg += ("\nUse the Theano flag 'exception_verbosity=high'"
" 'exception_verbosity=high' for more" " for a debugprint of this apply node.")
" information on the inputs of this apply"
" node.") exc_value = exc_type(str(exc_value) + detailed_err_msg)
exc_value = exc_type(str(exc_value) +
"\nApply node that caused the error: " + str(op) +
detailed_err_msg)
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
raise_with_op.print_thunk_trace = False raise_with_op.print_thunk_trace = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论