提交 3dafd49d authored 作者: Frederic Bastien's avatar Frederic Bastien

Better error message when the test_value do not have the right type

fix gh-1372
上级 565fd3eb
......@@ -16,8 +16,10 @@ import inspect
import logging
import numpy
import os
import sys
import re
import StringIO
import sys
import traceback
import warnings
import theano
......@@ -448,7 +450,31 @@ class PureOp(object):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct
return v.type.filter(v.tag.test_value)
try:
ret = v.type.filter(v.tag.test_value)
except Exception, e:
# Better error message.
detailed_err_msg = (
"For compute_test_value, one input test value do not"
" have the requested type.\n")
tr = getattr(v.tag, 'trace', None)
if tr:
sio = StringIO.StringIO()
traceback.print_list(tr, sio)
tr = sio.getvalue()
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
detailed_err_msg += str(tr)
detailed_err_msg += (
"\nThe error when converting the test value to that"
" variable type:")
# We need to only have 1 args and it should be of type
# string. Otherwise, it print the tuple and so the
# new line do not get printed.
args = (detailed_err_msg,) + tuple(str(arg) for arg in e.args)
e.args = ("\n".join(args),)
raise
return ret
raise AttributeError('%s has no test value' % v)
......
......@@ -307,8 +307,7 @@ class PureType(object):
def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name)
def __call__(self, name = None):
def __call__(self, name=None, limit=None):
"""Return a new `Variable` instance of Type `self`.
:Parameters:
......@@ -316,7 +315,7 @@ class PureType(object):
A pretty string for printing and debugging.
"""
return utils.add_tag_trace(self.make_variable(name))
return utils.add_tag_trace(self.make_variable(name), limit=limit)
def values_eq(self, a, b):
"""
......
......@@ -50,12 +50,16 @@ if sys.version_info[:2] > (3, 4):
simple_extract_stack = traceback.extract_stack
def add_tag_trace(thing):
def add_tag_trace(thing, limit=None):
"""Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace
:param limit: The limit of the stack size.
If None use, config.traceback.limit
"""
limit = config.traceback.limit
if limit is None:
limit = config.traceback.limit
if limit == -1:
limit = None
tr = simple_extract_stack(limit=limit)[:-1]
......
......@@ -744,7 +744,12 @@ def get_scalar_constant_value(orig_v, elemwise=True):
def tensor(*args, **kwargs):
name = kwargs.pop('name', None)
return TensorType(*args, **kwargs).make_variable(name=name)
# This add an indirection to the normal call stack. So raise the
# limit to keep the good user line.
limit = config.traceback.limit
if limit != -1:
limit += 1
return TensorType(*args, **kwargs)(name=name, limit=limit)
def _multi(*fns):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论