提交 cc0db312 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Samira Shabanian

Refactor how we print variable stacktrace and fix errors with trace in unpickled graph.

上级 037fd298
...@@ -13,7 +13,6 @@ import numpy ...@@ -13,7 +13,6 @@ import numpy
import os import os
import re import re
import sys import sys
import traceback
import warnings import warnings
import theano import theano
...@@ -21,7 +20,6 @@ from theano import config ...@@ -21,7 +20,6 @@ from theano import config
import theano.gof.cc import theano.gof.cc
from six import itervalues from six import itervalues
from six.moves import StringIO
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
...@@ -554,16 +552,7 @@ class PureOp(object): ...@@ -554,16 +552,7 @@ class PureOp(object):
detailed_err_msg = ( detailed_err_msg = (
"For compute_test_value, one input test value does not" "For compute_test_value, one input test value does not"
" have the requested type.\n") " have the requested type.\n")
tr = getattr(v.tag, 'trace', []) detailed_err_msg += utils.get_variable_trace_string(v)
if isinstance(tr, list) and len(tr) > 0:
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
# Print separate message for each element in the list
# of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
detailed_err_msg += str(sio.getvalue())
detailed_err_msg += ( detailed_err_msg += (
"\nThe error when converting the test value to that" "\nThe error when converting the test value to that"
...@@ -575,11 +564,7 @@ class PureOp(object): ...@@ -575,11 +564,7 @@ class PureOp(object):
e.args = ("\n".join(args),) e.args = ("\n".join(args),)
raise raise
return ret return ret
detailed_err_msg = utils.get_variable_trace_string(v)
sio = StringIO()
for subtr in getattr(v.tag, "trace", []):
traceback.print_list(subtr, sio)
detailed_err_msg = sio.getvalue()
raise AttributeError('%s has no test value %s' % (v, detailed_err_msg)) raise AttributeError('%s has no test value %s' % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
...@@ -634,10 +619,7 @@ class PureOp(object): ...@@ -634,10 +619,7 @@ class PureOp(object):
(i, ins, node), stacklevel=2) (i, ins, node), stacklevel=2)
run_perform = False run_perform = False
elif config.compute_test_value == 'raise': elif config.compute_test_value == 'raise':
sio = StringIO() detailed_err_msg = utils.get_variable_trace_string(ins)
for subtr in getattr(ins.tag, "trace", []):
traceback.print_list(subtr, sio)
detailed_err_msg = sio.getvalue()
raise ValueError( raise ValueError(
'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' % 'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' %
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import linecache import linecache
import sys import sys
import traceback
import numpy import numpy
from six import iteritems, integer_types, string_types from six import iteritems, integer_types, string_types
from six.moves import StringIO
from theano import config from theano import config
from theano.compat import OrderedDict, PY3 from theano.compat import OrderedDict, PY3
...@@ -112,6 +114,23 @@ def add_tag_trace(thing, user_line=None): ...@@ -112,6 +114,23 @@ def add_tag_trace(thing, user_line=None):
return thing return thing
def get_variable_trace_string(v):
sio = StringIO()
# For backward compatibility with old trace
tr = getattr(v.tag, 'trace', [])
if isinstance(tr, list) and len(tr) > 0:
print(" \nBacktrace when that variable is created:\n", file=sio)
# The isinstance is needed to handle old pickled trace
if isinstance(tr[0], tuple):
traceback.print_list(v.tag.trace, sio)
else:
# Print separate message for each element in the list of
# batcktraces
for subtr in tr:
traceback.print_list(subtr, sio)
return sio.getvalue()
def hashtype(self): def hashtype(self):
t = type(self) t = type(self)
return hash(t.__name__) ^ hash(t.__module__) return hash(t.__name__) ^ hash(t.__module__)
......
...@@ -3,17 +3,15 @@ from __future__ import absolute_import, print_function, division ...@@ -3,17 +3,15 @@ from __future__ import absolute_import, print_function, division
import six.moves.builtins as builtins import six.moves.builtins as builtins
import logging import logging
import time import time
import traceback
import warnings import warnings
import numpy # for numeric_grad import numpy # for numeric_grad
from six import itervalues from six import itervalues
from six.moves import StringIO
import theano import theano
from theano import gof from theano import gof
from theano.gof import Variable from theano.gof import utils, Variable
from theano.compat import OrderedDict, izip from theano.compat import OrderedDict, izip
from six.moves import xrange, reduce from six.moves import xrange, reduce
from theano.gof.null_type import NullType, null_type from theano.gof.null_type import NullType, null_type
...@@ -518,17 +516,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -518,17 +516,7 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2) warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise': elif disconnected_inputs == 'raise':
# Add the var trace message = utils.get_variable_trace_string(var)
tr = getattr(var.tag, 'trace', [])
if len(tr) > 0:
message += "\nBacktrace when the node is created:\n"
# Print separate message for each element in the list of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
message += str(sio.getvalue())
raise DisconnectedInputError(message) raise DisconnectedInputError(message)
else: else:
raise ValueError("Invalid value for keyword " raise ValueError("Invalid value for keyword "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论