提交 8fc19902 authored 作者: abergeron's avatar abergeron

Merge pull request #2504 from nouiz/err_msg

Better error message when the test_value do not have the right type
......@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit',
"The number of stack to trace. -1 mean all.",
# We default to 6 to be able to know where v1 + v2 is created in the
# user script. The bigger this number is, the more run time it takes.
IntParam(6),
# We need to default to 7 to support theano.tensor.tensor(...).
IntParam(7),
in_c_key=False)
AddConfigVar('experimental.mrg',
......
......@@ -16,8 +16,10 @@ import inspect
import logging
import numpy
import os
import sys
import re
import StringIO
import sys
import traceback
import warnings
import theano
......@@ -448,7 +450,31 @@ class PureOp(object):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct
return v.type.filter(v.tag.test_value)
try:
ret = v.type.filter(v.tag.test_value)
except Exception, e:
# Better error message.
detailed_err_msg = (
"For compute_test_value, one input test value does not"
" have the requested type.\n")
tr = getattr(v.tag, 'trace', None)
if tr:
sio = StringIO.StringIO()
traceback.print_list(tr, sio)
tr = sio.getvalue()
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
detailed_err_msg += str(tr)
detailed_err_msg += (
"\nThe error when converting the test value to that"
" variable type:")
# We need to only have 1 args and it should be of type
# string. Otherwise, it print the tuple and so the
# new line do not get printed.
args = (detailed_err_msg,) + tuple(str(arg) for arg in e.args)
e.args = ("\n".join(args),)
raise
return ret
raise AttributeError('%s has no test value' % v)
......
......@@ -307,8 +307,7 @@ class PureType(object):
def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name)
def __call__(self, name = None):
def __call__(self, name=None):
"""Return a new `Variable` instance of Type `self`.
:Parameters:
......
......@@ -50,10 +50,16 @@ if sys.version_info[:2] > (3, 4):
simple_extract_stack = traceback.extract_stack
def add_tag_trace(thing):
def add_tag_trace(thing, user_line=1):
"""Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace
:param user_line: The max number of user line to keep.
:note: we alse use config.traceback.limit for the maximum number
of stack level we look.
"""
limit = config.traceback.limit
if limit == -1:
......@@ -68,14 +74,21 @@ def add_tag_trace(thing):
file_path = tr[-1][0]
rm = False
for p in ["theano/tensor/",
"theano/gof/"]:
"theano/gof/",
"theano/scalar/basic.py",
"theano/sandbox/",
"theano/scan_module/",
"theano/sparse/",
"theano/typed_list/",
]:
if p in file_path:
tr = tr[:-1]
rm = True
break
if not rm:
break
if len(tr) > user_line:
tr = tr[:user_line]
thing.tag.trace = tr
return thing
......
......@@ -744,7 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
def tensor(*args, **kwargs):
name = kwargs.pop('name', None)
return TensorType(*args, **kwargs).make_variable(name=name)
return TensorType(*args, **kwargs)(name=name)
def _multi(*fns):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论